summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorJason Robinson <jasonr@matrix.org>2021-01-23 21:41:35 +0200
committerJason Robinson <jasonr@matrix.org>2021-01-23 21:41:35 +0200
commit8965b6cfec8a1de847efe3d1be4b9babf4622e2e (patch)
tree4551f104ee2ce840689aa5ecffa939938482ffd5 /synapse
parentAdd depth and received_ts to forward_extremities admin API response (diff)
parentReturn a 404 if no valid thumbnail is found. (#9163) (diff)
downloadsynapse-8965b6cfec8a1de847efe3d1be4b9babf4622e2e.tar.xz
Merge branch 'develop' into jaywink/admin-forward-extremities
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py7
-rw-r--r--synapse/api/urls.py2
-rw-r--r--synapse/app/_base.py150
-rw-r--r--synapse/app/generic_worker.py34
-rw-r--r--synapse/app/homeserver.py73
-rw-r--r--synapse/config/_base.py11
-rw-r--r--synapse/config/cas.py2
-rw-r--r--synapse/config/emailconfig.py8
-rw-r--r--synapse/config/oidc_config.py594
-rw-r--r--synapse/config/registration.py32
-rw-r--r--synapse/config/saml2_config.py2
-rw-r--r--synapse/config/server.py26
-rw-r--r--synapse/config/sso.py23
-rw-r--r--synapse/config/workers.py18
-rw-r--r--synapse/events/__init__.py3
-rw-r--r--synapse/federation/federation_client.py125
-rw-r--r--synapse/federation/federation_server.py2
-rw-r--r--synapse/federation/transport/server.py2
-rw-r--r--synapse/handlers/account_data.py144
-rw-r--r--synapse/handlers/auth.py120
-rw-r--r--synapse/handlers/cas_handler.py4
-rw-r--r--synapse/handlers/deactivate_account.py18
-rw-r--r--synapse/handlers/devicemessage.py2
-rw-r--r--synapse/handlers/identity.py2
-rw-r--r--synapse/handlers/oidc_handler.py599
-rw-r--r--synapse/handlers/profile.py8
-rw-r--r--synapse/handlers/read_marker.py5
-rw-r--r--synapse/handlers/receipts.py27
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/room_member.py7
-rw-r--r--synapse/handlers/saml_handler.py4
-rw-r--r--synapse/handlers/sso.py86
-rw-r--r--synapse/handlers/ui_auth/__init__.py15
-rw-r--r--synapse/http/__init__.py15
-rw-r--r--synapse/http/client.py25
-rw-r--r--synapse/http/endpoint.py79
-rw-r--r--synapse/http/federation/matrix_federation_agent.py1
-rw-r--r--synapse/http/matrixfederationclient.py12
-rw-r--r--synapse/http/proxyagent.py16
-rw-r--r--synapse/http/site.py18
-rw-r--r--synapse/replication/http/__init__.py2
-rw-r--r--synapse/replication/http/_base.py2
-rw-r--r--synapse/replication/http/account_data.py187
-rw-r--r--synapse/replication/slave/storage/_base.py10
-rw-r--r--synapse/replication/slave/storage/account_data.py40
-rw-r--r--synapse/replication/slave/storage/receipts.py35
-rw-r--r--synapse/replication/tcp/handler.py19
-rw-r--r--synapse/res/templates/sso_auth_bad_user.html18
-rw-r--r--synapse/res/templates/sso_login_idp_picker.html3
-rw-r--r--synapse/rest/admin/media.py64
-rw-r--r--synapse/rest/admin/users.py50
-rw-r--r--synapse/rest/client/v1/room.py20
-rw-r--r--synapse/rest/client/v2_alpha/account.py44
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py22
-rw-r--r--synapse/rest/client/v2_alpha/auth.py37
-rw-r--r--synapse/rest/client/v2_alpha/devices.py12
-rw-r--r--synapse/rest/client/v2_alpha/keys.py6
-rw-r--r--synapse/rest/client/v2_alpha/register.py21
-rw-r--r--synapse/rest/client/v2_alpha/tags.py11
-rw-r--r--synapse/rest/media/v1/_base.py79
-rw-r--r--synapse/rest/media/v1/config_resource.py14
-rw-r--r--synapse/rest/media/v1/download_resource.py18
-rw-r--r--synapse/rest/media/v1/filepath.py50
-rw-r--r--synapse/rest/media/v1/media_repository.py50
-rw-r--r--synapse/rest/media/v1/media_storage.py12
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py77
-rw-r--r--synapse/rest/media/v1/storage_provider.py37
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py313
-rw-r--r--synapse/rest/media/v1/thumbnailer.py18
-rw-r--r--synapse/rest/media/v1/upload_resource.py14
-rw-r--r--synapse/rest/synapse/client/pick_idp.py4
-rw-r--r--synapse/rest/well_known.py4
-rw-r--r--synapse/server.py11
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py3
-rw-r--r--synapse/server_notices/server_notices_manager.py3
-rw-r--r--synapse/storage/database.py44
-rw-r--r--synapse/storage/databases/main/__init__.py10
-rw-r--r--synapse/storage/databases/main/account_data.py147
-rw-r--r--synapse/storage/databases/main/client_ips.py41
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py129
-rw-r--r--synapse/storage/databases/main/event_federation.py185
-rw-r--r--synapse/storage/databases/main/event_push_actions.py96
-rw-r--r--synapse/storage/databases/main/events.py703
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py360
-rw-r--r--synapse/storage/databases/main/events_worker.py8
-rw-r--r--synapse/storage/databases/main/media_repository.py13
-rw-r--r--synapse/storage/databases/main/profile.py2
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/pusher.py5
-rw-r--r--synapse/storage/databases/main/receipts.py108
-rw-r--r--synapse/storage/databases/main/registration.py2
-rw-r--r--synapse/storage/databases/main/room.py57
-rw-r--r--synapse/storage/databases/main/roommember.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres16
-rw-r--r--synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite62
-rw-r--r--synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/01ignored_user.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql52
-rw-r--r--synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres16
-rw-r--r--synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql20
-rw-r--r--synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres32
-rw-r--r--synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql18
-rw-r--r--synapse/storage/databases/main/search.py4
-rw-r--r--synapse/storage/databases/main/tags.py20
-rw-r--r--synapse/storage/databases/main/transactions.py24
-rw-r--r--synapse/storage/databases/state/store.py4
-rw-r--r--synapse/storage/prepare_database.py3
-rw-r--r--synapse/storage/util/id_generators.py109
-rw-r--r--synapse/storage/util/sequence.py82
-rw-r--r--synapse/types.py8
-rw-r--r--synapse/util/iterutils.py53
-rw-r--r--synapse/util/stringutils.py111
118 files changed, 4528 insertions, 1737 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 99fb675748..d423856d82 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.25.0rc1"
+__version__ = "1.26.0rc1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 48c4d7b0be..67ecbd32ff 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
+from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
@@ -186,8 +187,8 @@ class Auth:
             AuthError if access is denied for the user in the access token
         """
         try:
-            ip_addr = self.hs.get_ip_from_request(request)
-            user_agent = request.get_user_agent("")
+            ip_addr = request.getClientIP()
+            user_agent = get_request_user_agent(request)
 
             access_token = self.get_access_token_from_request(request)
 
@@ -275,7 +276,7 @@ class Auth:
             return None, None
 
         if app_service.ip_range_whitelist:
-            ip_address = IPAddress(self.hs.get_ip_from_request(request))
+            ip_address = IPAddress(request.getClientIP())
             if ip_address not in app_service.ip_range_whitelist:
                 return None, None
 
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 6379c86dde..e36aeef31f 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -42,8 +42,6 @@ class ConsentURIBuilder:
         """
         if hs_config.form_secret is None:
             raise ConfigError("form_secret not set in config")
-        if hs_config.public_baseurl is None:
-            raise ConfigError("public_baseurl not set in config")
 
         self._hmac_secret = hs_config.form_secret.encode("utf-8")
         self._public_baseurl = hs_config.public_baseurl
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 37ecdbe3d8..395e202b89 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2017 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +20,7 @@ import signal
 import socket
 import sys
 import traceback
-from typing import Iterable
+from typing import Awaitable, Callable, Iterable
 
 from typing_extensions import NoReturn
 
@@ -143,6 +144,45 @@ def quit_with_error(error_string: str) -> NoReturn:
     sys.exit(1)
 
 
+def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
+    """Register a callback with the reactor, to be called once it is running
+
+    This can be used to initialise parts of the system which require an asynchronous
+    setup.
+
+    Any exception raised by the callback will be printed and logged, and the process
+    will exit.
+    """
+
+    async def wrapper():
+        try:
+            await cb(*args, **kwargs)
+        except Exception:
+            # previously, we used Failure().printTraceback() here, in the hope that
+            # would give better tracebacks than traceback.print_exc(). However, that
+            # doesn't handle chained exceptions (with a __cause__ or __context__) well,
+            # and I *think* the need for Failure() is reduced now that we mostly use
+            # async/await.
+
+            # Write the exception to both the logs *and* the unredirected stderr,
+            # because people tend to get confused if it only goes to one or the other.
+            #
+            # One problem with this is that if people are using a logging config that
+            # logs to the console (as is common eg under docker), they will get two
+            # copies of the exception. We could maybe try to detect that, but it's
+            # probably a cost we can bear.
+            logger.fatal("Error during startup", exc_info=True)
+            print("Error during startup:", file=sys.__stderr__)
+            traceback.print_exc(file=sys.__stderr__)
+
+            # it's no use calling sys.exit here, since that just raises a SystemExit
+            # exception which is then caught by the reactor, and everything carries
+            # on as normal.
+            os._exit(1)
+
+    reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
+
+
 def listen_metrics(bind_addresses, port):
     """
     Start Prometheus metrics server.
@@ -227,7 +267,7 @@ def refresh_certificate(hs):
         logger.info("Context factories updated.")
 
 
-def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
+async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
     """
     Start a Synapse server or worker.
 
@@ -241,75 +281,67 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
         hs: homeserver instance
         listeners: Listener configuration ('listeners' in homeserver.yaml)
     """
-    try:
-        # Set up the SIGHUP machinery.
-        if hasattr(signal, "SIGHUP"):
+    # Set up the SIGHUP machinery.
+    if hasattr(signal, "SIGHUP"):
+        reactor = hs.get_reactor()
 
-            reactor = hs.get_reactor()
+        @wrap_as_background_process("sighup")
+        def handle_sighup(*args, **kwargs):
+            # Tell systemd our state, if we're using it. This will silently fail if
+            # we're not using systemd.
+            sdnotify(b"RELOADING=1")
 
-            @wrap_as_background_process("sighup")
-            def handle_sighup(*args, **kwargs):
-                # Tell systemd our state, if we're using it. This will silently fail if
-                # we're not using systemd.
-                sdnotify(b"RELOADING=1")
+            for i, args, kwargs in _sighup_callbacks:
+                i(*args, **kwargs)
 
-                for i, args, kwargs in _sighup_callbacks:
-                    i(*args, **kwargs)
+            sdnotify(b"READY=1")
 
-                sdnotify(b"READY=1")
+        # We defer running the sighup handlers until next reactor tick. This
+        # is so that we're in a sane state, e.g. flushing the logs may fail
+        # if the sighup happens in the middle of writing a log entry.
+        def run_sighup(*args, **kwargs):
+            # `callFromThread` should be "signal safe" as well as thread
+            # safe.
+            reactor.callFromThread(handle_sighup, *args, **kwargs)
 
-            # We defer running the sighup handlers until next reactor tick. This
-            # is so that we're in a sane state, e.g. flushing the logs may fail
-            # if the sighup happens in the middle of writing a log entry.
-            def run_sighup(*args, **kwargs):
-                # `callFromThread` should be "signal safe" as well as thread
-                # safe.
-                reactor.callFromThread(handle_sighup, *args, **kwargs)
+        signal.signal(signal.SIGHUP, run_sighup)
 
-            signal.signal(signal.SIGHUP, run_sighup)
+        register_sighup(refresh_certificate, hs)
 
-            register_sighup(refresh_certificate, hs)
+    # Load the certificate from disk.
+    refresh_certificate(hs)
 
-        # Load the certificate from disk.
-        refresh_certificate(hs)
+    # Start the tracer
+    synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
+        hs
+    )
 
-        # Start the tracer
-        synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
-            hs
-        )
+    # It is now safe to start your Synapse.
+    hs.start_listening(listeners)
+    hs.get_datastore().db_pool.start_profiling()
+    hs.get_pusherpool().start()
+
+    # Log when we start the shut down process.
+    hs.get_reactor().addSystemEventTrigger(
+        "before", "shutdown", logger.info, "Shutting down..."
+    )
 
-        # It is now safe to start your Synapse.
-        hs.start_listening(listeners)
-        hs.get_datastore().db_pool.start_profiling()
-        hs.get_pusherpool().start()
+    setup_sentry(hs)
+    setup_sdnotify(hs)
 
-        # Log when we start the shut down process.
-        hs.get_reactor().addSystemEventTrigger(
-            "before", "shutdown", logger.info, "Shutting down..."
-        )
+    # If background tasks are running on the main process, start collecting the
+    # phone home stats.
+    if hs.config.run_background_tasks:
+        start_phone_stats_home(hs)
 
-        setup_sentry(hs)
-        setup_sdnotify(hs)
-
-        # If background tasks are running on the main process, start collecting the
-        # phone home stats.
-        if hs.config.run_background_tasks:
-            start_phone_stats_home(hs)
-
-        # We now freeze all allocated objects in the hopes that (almost)
-        # everything currently allocated are things that will be used for the
-        # rest of time. Doing so means less work each GC (hopefully).
-        #
-        # This only works on Python 3.7
-        if sys.version_info >= (3, 7):
-            gc.collect()
-            gc.freeze()
-    except Exception:
-        traceback.print_exc(file=sys.stderr)
-        reactor = hs.get_reactor()
-        if reactor.running:
-            reactor.stop()
-        sys.exit(1)
+    # We now freeze all allocated objects in the hopes that (almost)
+    # everything currently allocated are things that will be used for the
+    # rest of time. Doing so means less work each GC (hopefully).
+    #
+    # This only works on Python 3.7
+    if sys.version_info >= (3, 7):
+        gc.collect()
+        gc.freeze()
 
 
 def setup_sentry(hs):
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 4428472707..e60988fa4a 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
 
 from typing_extensions import ContextManager
 
-from twisted.internet import address, reactor
+from twisted.internet import address
 
 import synapse
 import synapse.events
@@ -34,6 +34,7 @@ from synapse.api.urls import (
     SERVER_KEY_V2_PREFIX,
 )
 from synapse.app import _base
+from synapse.app._base import register_start
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
@@ -99,14 +100,28 @@ from synapse.rest.client.v1.profile import (
 )
 from synapse.rest.client.v1.push_rule import PushRuleRestServlet
 from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, sync, user_directory
+from synapse.rest.client.v2_alpha import (
+    account_data,
+    groups,
+    read_marker,
+    receipts,
+    room_keys,
+    sync,
+    tags,
+    user_directory,
+)
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
 from synapse.rest.client.v2_alpha.account_data import (
     AccountDataServlet,
     RoomAccountDataServlet,
 )
-from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
+from synapse.rest.client.v2_alpha.devices import DevicesRestServlet
+from synapse.rest.client.v2_alpha.keys import (
+    KeyChangesServlet,
+    KeyQueryServlet,
+    OneTimeKeyServlet,
+)
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
 from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
@@ -115,6 +130,7 @@ from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.server import HomeServer, cache_in_self
 from synapse.storage.databases.main.censor_events import CensorEventsStore
 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
+from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
 from synapse.storage.databases.main.media_repository import MediaRepositoryStore
 from synapse.storage.databases.main.metrics import ServerMetricsStore
 from synapse.storage.databases.main.monthly_active_users import (
@@ -446,6 +462,7 @@ class GenericWorkerSlavedStore(
     UserDirectoryStore,
     StatsStore,
     UIAuthWorkerStore,
+    EndToEndRoomKeyStore,
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedReceiptsStore,
@@ -502,7 +519,9 @@ class GenericWorkerServer(HomeServer):
                     RegisterRestServlet(self).register(resource)
                     LoginRestServlet(self).register(resource)
                     ThreepidRestServlet(self).register(resource)
+                    DevicesRestServlet(self).register(resource)
                     KeyQueryServlet(self).register(resource)
+                    OneTimeKeyServlet(self).register(resource)
                     KeyChangesServlet(self).register(resource)
                     VoipRestServlet(self).register(resource)
                     PushRuleRestServlet(self).register(resource)
@@ -520,6 +539,11 @@ class GenericWorkerServer(HomeServer):
                     room.register_servlets(self, resource, True)
                     room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)
+                    room_keys.register_servlets(self, resource)
+                    tags.register_servlets(self, resource)
+                    account_data.register_servlets(self, resource)
+                    receipts.register_servlets(self, resource)
+                    read_marker.register_servlets(self, resource)
 
                     SendToDeviceRestServlet(self).register(resource)
 
@@ -960,9 +984,7 @@ def start(config_options):
     # streams. Will no-op if no streams can be written to by this worker.
     hs.get_replication_streamer()
 
-    reactor.addSystemEventTrigger(
-        "before", "startup", _base.start, hs, config.worker_listeners
-    )
+    register_start(_base.start, hs, config.worker_listeners)
 
     _base.start_worker_reactor("synapse-generic-worker", config)
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index b1d9817a6a..57a2f5237c 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -15,15 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import logging
 import os
 import sys
 from typing import Iterable, Iterator
 
-from twisted.application import service
-from twisted.internet import defer, reactor
-from twisted.python.failure import Failure
+from twisted.internet import reactor
 from twisted.web.resource import EncodingResourceWrapper, IResource
 from twisted.web.server import GzipEncoderFactory
 from twisted.web.static import File
@@ -40,7 +37,7 @@ from synapse.api.urls import (
     WEB_CLIENT_PREFIX,
 )
 from synapse.app import _base
-from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
+from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
 from synapse.config._base import ConfigError
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
@@ -73,7 +70,6 @@ from synapse.storage.prepare_database import UpgradeDatabaseException
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.module_loader import load_module
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse.app.homeserver")
@@ -417,40 +413,28 @@ def setup(config_options):
             _base.refresh_certificate(hs)
 
     async def start():
-        try:
-            # Run the ACME provisioning code, if it's enabled.
-            if hs.config.acme_enabled:
-                acme = hs.get_acme_handler()
-                # Start up the webservices which we will respond to ACME
-                # challenges with, and then provision.
-                await acme.start_listening()
-                await do_acme()
+        # Run the ACME provisioning code, if it's enabled.
+        if hs.config.acme_enabled:
+            acme = hs.get_acme_handler()
+            # Start up the webservices which we will respond to ACME
+            # challenges with, and then provision.
+            await acme.start_listening()
+            await do_acme()
 
-                # Check if it needs to be reprovisioned every day.
-                hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
+            # Check if it needs to be reprovisioned every day.
+            hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
 
-            # Load the OIDC provider metadatas, if OIDC is enabled.
-            if hs.config.oidc_enabled:
-                oidc = hs.get_oidc_handler()
-                # Loading the provider metadata also ensures the provider config is valid.
-                await oidc.load_metadata()
-                await oidc.load_jwks()
+        # Load the OIDC provider metadatas, if OIDC is enabled.
+        if hs.config.oidc_enabled:
+            oidc = hs.get_oidc_handler()
+            # Loading the provider metadata also ensures the provider config is valid.
+            await oidc.load_metadata()
 
-            _base.start(hs, config.listeners)
+        await _base.start(hs, config.listeners)
 
-            hs.get_datastore().db_pool.updates.start_doing_background_updates()
-        except Exception:
-            # Print the exception and bail out.
-            print("Error during startup:", file=sys.stderr)
+        hs.get_datastore().db_pool.updates.start_doing_background_updates()
 
-            # this gives better tracebacks than traceback.print_exc()
-            Failure().printTraceback(file=sys.stderr)
-
-            if reactor.running:
-                reactor.stop()
-            sys.exit(1)
-
-    reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
+    register_start(start)
 
     return hs
 
@@ -487,25 +471,6 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
         e = e.__cause__
 
 
-class SynapseService(service.Service):
-    """
-    A twisted Service class that will start synapse. Used to run synapse
-    via twistd and a .tac.
-    """
-
-    def __init__(self, config):
-        self.config = config
-
-    def startService(self):
-        hs = setup(self.config)
-        change_resource_limit(hs.config.soft_file_limit)
-        if hs.config.gc_thresholds:
-            gc.set_threshold(*hs.config.gc_thresholds)
-
-    def stopService(self):
-        return self._port.stopListening()
-
-
 def run(hs):
     PROFILE_SYNAPSE = False
     if PROFILE_SYNAPSE:
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 2931a88207..94144efc87 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -252,11 +252,12 @@ class Config:
         env = jinja2.Environment(loader=loader, autoescape=autoescape)
 
         # Update the environment with our custom filters
-        env.filters.update({"format_ts": _format_ts_filter})
-        if self.public_baseurl:
-            env.filters.update(
-                {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)}
-            )
+        env.filters.update(
+            {
+                "format_ts": _format_ts_filter,
+                "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+            }
+        )
 
         for filename in filenames:
             # Load the template
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 2f97e6d258..c7877b4095 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -40,7 +40,7 @@ class CasConfig(Config):
             self.cas_required_attributes = {}
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
-        return """
+        return """\
         # Enable Central Authentication Service (CAS) for registration and login.
         #
         cas_config:
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index d4328c46b9..6a487afd34 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -166,11 +166,6 @@ class EmailConfig(Config):
             if not self.email_notif_from:
                 missing.append("email.notif_from")
 
-            # public_baseurl is required to build password reset and validation links that
-            # will be emailed to users
-            if config.get("public_baseurl") is None:
-                missing.append("public_baseurl")
-
             if missing:
                 raise ConfigError(
                     MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),)
@@ -269,9 +264,6 @@ class EmailConfig(Config):
             if not self.email_notif_from:
                 missing.append("email.notif_from")
 
-            if config.get("public_baseurl") is None:
-                missing.append("public_baseurl")
-
             if missing:
                 raise ConfigError(
                     "email.enable_notifs is True but required keys are missing: %s"
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 4e3055282d..bfeceeed18 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2020 Quentin Gliech
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,8 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import string
+from collections import Counter
+from typing import Iterable, Optional, Tuple, Type
+
+import attr
+
+from synapse.config._util import validate_config
 from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.types import Collection, JsonDict
 from synapse.util.module_loader import load_module
+from synapse.util.stringutils import parse_and_validate_mxc_uri
 
 from ._base import Config, ConfigError
 
@@ -25,202 +35,442 @@ class OIDCConfig(Config):
     section = "oidc"
 
     def read_config(self, config, **kwargs):
-        self.oidc_enabled = False
-
-        oidc_config = config.get("oidc_config")
-
-        if not oidc_config or not oidc_config.get("enabled", False):
+        self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
+        if not self.oidc_providers:
             return
 
         try:
             check_requirements("oidc")
         except DependencyException as e:
-            raise ConfigError(e.message)
+            raise ConfigError(e.message) from e
+
+        # check we don't have any duplicate idp_ids now. (The SSO handler will also
+        # check for duplicates when the REST listeners get registered, but that happens
+        # after synapse has forked so doesn't give nice errors.)
+        c = Counter([i.idp_id for i in self.oidc_providers])
+        for idp_id, count in c.items():
+            if count > 1:
+                raise ConfigError(
+                    "Multiple OIDC providers have the idp_id %r." % idp_id
+                )
 
         public_baseurl = self.public_baseurl
-        if public_baseurl is None:
-            raise ConfigError("oidc_config requires a public_baseurl to be set")
         self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
 
-        self.oidc_enabled = True
-        self.oidc_discover = oidc_config.get("discover", True)
-        self.oidc_issuer = oidc_config["issuer"]
-        self.oidc_client_id = oidc_config["client_id"]
-        self.oidc_client_secret = oidc_config["client_secret"]
-        self.oidc_client_auth_method = oidc_config.get(
-            "client_auth_method", "client_secret_basic"
-        )
-        self.oidc_scopes = oidc_config.get("scopes", ["openid"])
-        self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
-        self.oidc_token_endpoint = oidc_config.get("token_endpoint")
-        self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
-        self.oidc_jwks_uri = oidc_config.get("jwks_uri")
-        self.oidc_skip_verification = oidc_config.get("skip_verification", False)
-        self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
-        self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
-
-        ump_config = oidc_config.get("user_mapping_provider", {})
-        ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
-        ump_config.setdefault("config", {})
-
-        (
-            self.oidc_user_mapping_provider_class,
-            self.oidc_user_mapping_provider_config,
-        ) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
-
-        # Ensure loaded user mapping module has defined all necessary methods
-        required_methods = [
-            "get_remote_user_id",
-            "map_user_attributes",
-        ]
-        missing_methods = [
-            method
-            for method in required_methods
-            if not hasattr(self.oidc_user_mapping_provider_class, method)
-        ]
-        if missing_methods:
-            raise ConfigError(
-                "Class specified by oidc_config."
-                "user_mapping_provider.module is missing required "
-                "methods: %s" % (", ".join(missing_methods),)
-            )
+    @property
+    def oidc_enabled(self) -> bool:
+        # OIDC is enabled if we have a provider
+        return bool(self.oidc_providers)
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
-        # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
+        # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+        # and login.
+        #
+        # Options for each entry include:
+        #
+        #   idp_id: a unique identifier for this identity provider. Used internally
+        #       by Synapse; should be a single word such as 'github'.
+        #
+        #       Note that, if this is changed, users authenticating via that provider
+        #       will no longer be recognised as the same user!
+        #
+        #   idp_name: A user-facing name for this identity provider, which is used to
+        #       offer the user a choice of login mechanisms.
+        #
+        #   idp_icon: An optional icon for this identity provider, which is presented
+        #       by identity picker pages. If given, must be an MXC URI of the format
+        #       mxc://<server-name>/<media-id>. (An easy way to obtain such an MXC URI
+        #       is to upload an image to an (unencrypted) room and then copy the "url"
+        #       from the source of the event.)
+        #
+        #   discover: set to 'false' to disable the use of the OIDC discovery mechanism
+        #       to discover endpoints. Defaults to true.
+        #
+        #   issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+        #       is enabled) to discover the provider's endpoints.
+        #
+        #   client_id: Required. oauth2 client id to use.
+        #
+        #   client_secret: Required. oauth2 client secret to use.
+        #
+        #   client_auth_method: auth method to use when exchanging the token. Valid
+        #       values are 'client_secret_basic' (default), 'client_secret_post' and
+        #       'none'.
+        #
+        #   scopes: list of scopes to request. This should normally include the "openid"
+        #       scope. Defaults to ["openid"].
+        #
+        #   authorization_endpoint: the oauth2 authorization endpoint. Required if
+        #       provider discovery is disabled.
+        #
+        #   token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+        #       disabled.
+        #
+        #   userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+        #       disabled and the 'openid' scope is not requested.
+        #
+        #   jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+        #       the 'openid' scope is used.
+        #
+        #   skip_verification: set to 'true' to skip metadata verification. Use this if
+        #       you are connecting to a provider that is not OpenID Connect compliant.
+        #       Defaults to false. Avoid this in production.
+        #
+        #   user_profile_method: Whether to fetch the user profile from the userinfo
+        #       endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+        #
+        #       Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+        #       included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+        #       userinfo endpoint.
+        #
+        #   allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+        #       match a pre-existing account instead of failing. This could be used if
+        #       switching from password logins to OIDC. Defaults to false.
+        #
+        #   user_mapping_provider: Configuration for how attributes returned from a OIDC
+        #       provider are mapped onto a matrix user. This setting has the following
+        #       sub-properties:
+        #
+        #       module: The class name of a custom mapping module. Default is
+        #           {mapping_provider!r}.
+        #           See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+        #           for information on implementing a custom mapping provider.
+        #
+        #       config: Configuration for the mapping provider module. This section will
+        #           be passed as a Python dictionary to the user mapping provider
+        #           module's `parse_config` method.
+        #
+        #           For the default provider, the following settings are available:
+        #
+        #             sub: name of the claim containing a unique identifier for the
+        #                 user. Defaults to 'sub', which OpenID Connect compliant
+        #                 providers should provide.
+        #
+        #             localpart_template: Jinja2 template for the localpart of the MXID.
+        #                 If this is not set, the user will be prompted to choose their
+        #                 own username.
+        #
+        #             display_name_template: Jinja2 template for the display name to set
+        #                 on first login. If unset, no displayname will be set.
+        #
+        #             extra_attributes: a map of Jinja2 templates for extra attributes
+        #                 to send back to the client during login.
+        #                 Note that these are non-standard and clients will ignore them
+        #                 without modifications.
+        #
+        #           When rendering, the Jinja2 templates are given a 'user' variable,
+        #           which is set to the claims returned by the UserInfo Endpoint and/or
+        #           in the ID Token.
         #
         # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
-        # for some example configurations.
+        # for information on how to configure these options.
         #
-        oidc_config:
-          # Uncomment the following to enable authorization against an OpenID Connect
-          # server. Defaults to false.
-          #
-          #enabled: true
-
-          # Uncomment the following to disable use of the OIDC discovery mechanism to
-          # discover endpoints. Defaults to true.
-          #
-          #discover: false
-
-          # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
-          # discover the provider's endpoints.
-          #
-          # Required if 'enabled' is true.
-          #
-          #issuer: "https://accounts.example.com/"
-
-          # oauth2 client id to use.
-          #
-          # Required if 'enabled' is true.
-          #
-          #client_id: "provided-by-your-issuer"
-
-          # oauth2 client secret to use.
+        # For backwards compatibility, it is also possible to configure a single OIDC
+        # provider via an 'oidc_config' setting. This is now deprecated and admins are
+        # advised to migrate to the 'oidc_providers' format. (When doing that migration,
+        # use 'oidc' for the idp_id to ensure that existing users continue to be
+        # recognised.)
+        #
+        oidc_providers:
+          # Generic example
           #
-          # Required if 'enabled' is true.
+          #- idp_id: my_idp
+          #  idp_name: "My OpenID provider"
+          #  idp_icon: "mxc://example.com/mediaid"
+          #  discover: false
+          #  issuer: "https://accounts.example.com/"
+          #  client_id: "provided-by-your-issuer"
+          #  client_secret: "provided-by-your-issuer"
+          #  client_auth_method: client_secret_post
+          #  scopes: ["openid", "profile"]
+          #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+          #  token_endpoint: "https://accounts.example.com/oauth2/token"
+          #  userinfo_endpoint: "https://accounts.example.com/userinfo"
+          #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+          #  skip_verification: true
+
+          # For use with Keycloak
           #
-          #client_secret: "provided-by-your-issuer"
-
-          # auth method to use when exchanging the token.
-          # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
-          # 'none'.
+          #- idp_id: keycloak
+          #  idp_name: Keycloak
+          #  issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+          #  client_id: "synapse"
+          #  client_secret: "copy secret generated in Keycloak UI"
+          #  scopes: ["openid", "profile"]
+
+          # For use with Github
           #
-          #client_auth_method: client_secret_post
+          #- idp_id: github
+          #  idp_name: Github
+          #  discover: false
+          #  issuer: "https://github.com/"
+          #  client_id: "your-client-id" # TO BE FILLED
+          #  client_secret: "your-client-secret" # TO BE FILLED
+          #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+          #  token_endpoint: "https://github.com/login/oauth/access_token"
+          #  userinfo_endpoint: "https://api.github.com/user"
+          #  scopes: ["read:user"]
+          #  user_mapping_provider:
+          #    config:
+          #      subject_claim: "id"
+          #      localpart_template: "{{ user.login }}"
+          #      display_name_template: "{{ user.name }}"
+        """.format(
+            mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
+        )
 
-          # list of scopes to request. This should normally include the "openid" scope.
-          # Defaults to ["openid"].
-          #
-          #scopes: ["openid", "profile"]
 
-          # the oauth2 authorization endpoint. Required if provider discovery is disabled.
-          #
-          #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+# jsonschema definition of the configuration settings for an oidc identity provider
+OIDC_PROVIDER_CONFIG_SCHEMA = {
+    "type": "object",
+    "required": ["issuer", "client_id", "client_secret"],
+    "properties": {
+        # TODO: fix the maxLength here depending on what MSC2528 decides
+        #   remember that we prefix the ID given here with `oidc-`
+        "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
+        "idp_name": {"type": "string"},
+        "idp_icon": {"type": "string"},
+        "discover": {"type": "boolean"},
+        "issuer": {"type": "string"},
+        "client_id": {"type": "string"},
+        "client_secret": {"type": "string"},
+        "client_auth_method": {
+            "type": "string",
+            # the following list is the same as the keys of
+            # authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
+            # to avoid importing authlib here.
+            "enum": ["client_secret_basic", "client_secret_post", "none"],
+        },
+        "scopes": {"type": "array", "items": {"type": "string"}},
+        "authorization_endpoint": {"type": "string"},
+        "token_endpoint": {"type": "string"},
+        "userinfo_endpoint": {"type": "string"},
+        "jwks_uri": {"type": "string"},
+        "skip_verification": {"type": "boolean"},
+        "user_profile_method": {
+            "type": "string",
+            "enum": ["auto", "userinfo_endpoint"],
+        },
+        "allow_existing_users": {"type": "boolean"},
+        "user_mapping_provider": {"type": ["object", "null"]},
+    },
+}
+
+# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
+OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
+    "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
+}
+
+
+# the `oidc_providers` list can either be None (as it is in the default config), or
+# a list of provider configs, each of which requires an explicit ID and name.
+OIDC_PROVIDER_LIST_SCHEMA = {
+    "oneOf": [
+        {"type": "null"},
+        {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
+    ]
+}
+
+# the `oidc_config` setting can either be None (which it used to be in the default
+# config), or an object. If an object, it is ignored unless it has an "enabled: True"
+# property.
+#
+# It's *possible* to represent this with jsonschema, but the resultant errors aren't
+# particularly clear, so we just check for either an object or a null here, and do
+# additional checks in the code.
+OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
+
+# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
+MAIN_CONFIG_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "oidc_config": OIDC_CONFIG_SCHEMA,
+        "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
+    },
+}
+
+
+def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
+    """extract and parse the OIDC provider configs from the config dict
+
+    The configuration may contain either a single `oidc_config` object with an
+    `enabled: True` property, or a list of provider configurations under
+    `oidc_providers`, *or both*.
+
+    Returns a generator which yields the OidcProviderConfig objects
+    """
+    validate_config(MAIN_CONFIG_SCHEMA, config, ())
+
+    for i, p in enumerate(config.get("oidc_providers") or []):
+        yield _parse_oidc_config_dict(p, ("oidc_providers", "<item %i>" % (i,)))
+
+    # for backwards-compatibility, it is also possible to provide a single "oidc_config"
+    # object with an "enabled: True" property.
+    oidc_config = config.get("oidc_config")
+    if oidc_config and oidc_config.get("enabled", False):
+        # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
+        # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
+        # above), so now we need to validate it.
+        validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
+        yield _parse_oidc_config_dict(oidc_config, ("oidc_config",))
+
+
+def _parse_oidc_config_dict(
+    oidc_config: JsonDict, config_path: Tuple[str, ...]
+) -> "OidcProviderConfig":
+    """Take the configuration dict and parse it into an OidcProviderConfig
+
+    Raises:
+        ConfigError if the configuration is malformed.
+    """
+    ump_config = oidc_config.get("user_mapping_provider", {})
+    ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+    ump_config.setdefault("config", {})
+
+    (user_mapping_provider_class, user_mapping_provider_config,) = load_module(
+        ump_config, config_path + ("user_mapping_provider",)
+    )
+
+    # Ensure loaded user mapping module has defined all necessary methods
+    required_methods = [
+        "get_remote_user_id",
+        "map_user_attributes",
+    ]
+    missing_methods = [
+        method
+        for method in required_methods
+        if not hasattr(user_mapping_provider_class, method)
+    ]
+    if missing_methods:
+        raise ConfigError(
+            "Class %s is missing required "
+            "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
+            config_path + ("user_mapping_provider", "module"),
+        )
 
-          # the oauth2 token endpoint. Required if provider discovery is disabled.
-          #
-          #token_endpoint: "https://accounts.example.com/oauth2/token"
+    # MSC2858 will apply certain limits in what can be used as an IdP id, so let's
+    # enforce those limits now.
+    # TODO: factor out this stuff to a generic function
+    idp_id = oidc_config.get("idp_id", "oidc")
 
-          # the OIDC userinfo endpoint. Required if discovery is disabled and the
-          # "openid" scope is not requested.
-          #
-          #userinfo_endpoint: "https://accounts.example.com/userinfo"
+    # TODO: update this validity check based on what MSC2858 decides.
+    valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._")
 
-          # URI where to fetch the JWKS. Required if discovery is disabled and the
-          # "openid" scope is used.
-          #
-          #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+    if any(c not in valid_idp_chars for c in idp_id):
+        raise ConfigError(
+            'idp_id may only contain a-z, 0-9, "-", ".", "_"',
+            config_path + ("idp_id",),
+        )
 
-          # Uncomment to skip metadata verification. Defaults to false.
-          #
-          # Use this if you are connecting to a provider that is not OpenID Connect
-          # compliant.
-          # Avoid this in production.
-          #
-          #skip_verification: true
+    if idp_id[0] not in string.ascii_lowercase:
+        raise ConfigError(
+            "idp_id must start with a-z", config_path + ("idp_id",),
+        )
 
-          # Whether to fetch the user profile from the userinfo endpoint. Valid
-          # values are: "auto" or "userinfo_endpoint".
-          #
-          # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
-          # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
-          #
-          #user_profile_method: "userinfo_endpoint"
+    # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid
+    # clashes with other mechs (such as SAML, CAS).
+    #
+    # We allow "oidc" as an exception so that people migrating from old-style
+    # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to
+    # a new-style "oidc_providers" entry without changing the idp_id for their provider
+    # (and thereby invalidating their user_external_ids data).
 
-          # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
-          # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
-          #
-          #allow_existing_users: true
+    if idp_id != "oidc":
+        idp_id = "oidc-" + idp_id
 
-          # An external module can be provided here as a custom solution to mapping
-          # attributes returned from a OIDC provider onto a matrix user.
-          #
-          user_mapping_provider:
-            # The custom module's class. Uncomment to use a custom module.
-            # Default is {mapping_provider!r}.
-            #
-            # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
-            # for information on implementing a custom mapping provider.
-            #
-            #module: mapping_provider.OidcMappingProvider
-
-            # Custom configuration values for the module. This section will be passed as
-            # a Python dictionary to the user mapping provider module's `parse_config`
-            # method.
-            #
-            # The examples below are intended for the default provider: they should be
-            # changed if using a custom provider.
-            #
-            config:
-              # name of the claim containing a unique identifier for the user.
-              # Defaults to `sub`, which OpenID Connect compliant providers should provide.
-              #
-              #subject_claim: "sub"
-
-              # Jinja2 template for the localpart of the MXID.
-              #
-              # When rendering, this template is given the following variables:
-              #   * user: The claims returned by the UserInfo Endpoint and/or in the ID
-              #     Token
-              #
-              # If this is not set, the user will be prompted to choose their
-              # own username.
-              #
-              #localpart_template: "{{{{ user.preferred_username }}}}"
-
-              # Jinja2 template for the display name to set on first login.
-              #
-              # If unset, no displayname will be set.
-              #
-              #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
-
-              # Jinja2 templates for extra attributes to send back to the client during
-              # login.
-              #
-              # Note that these are non-standard and clients will ignore them without modifications.
-              #
-              #extra_attributes:
-                #birthdate: "{{{{ user.birthdate }}}}"
-        """.format(
-            mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
-        )
+    # MSC2858 also specifies that the idp_icon must be a valid MXC uri
+    idp_icon = oidc_config.get("idp_icon")
+    if idp_icon is not None:
+        try:
+            parse_and_validate_mxc_uri(idp_icon)
+        except ValueError as e:
+            raise ConfigError(
+                "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
+            ) from e
+
+    return OidcProviderConfig(
+        idp_id=idp_id,
+        idp_name=oidc_config.get("idp_name", "OIDC"),
+        idp_icon=idp_icon,
+        discover=oidc_config.get("discover", True),
+        issuer=oidc_config["issuer"],
+        client_id=oidc_config["client_id"],
+        client_secret=oidc_config["client_secret"],
+        client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
+        scopes=oidc_config.get("scopes", ["openid"]),
+        authorization_endpoint=oidc_config.get("authorization_endpoint"),
+        token_endpoint=oidc_config.get("token_endpoint"),
+        userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
+        jwks_uri=oidc_config.get("jwks_uri"),
+        skip_verification=oidc_config.get("skip_verification", False),
+        user_profile_method=oidc_config.get("user_profile_method", "auto"),
+        allow_existing_users=oidc_config.get("allow_existing_users", False),
+        user_mapping_provider_class=user_mapping_provider_class,
+        user_mapping_provider_config=user_mapping_provider_config,
+    )
+
+
+@attr.s(slots=True, frozen=True)
+class OidcProviderConfig:
+    # a unique identifier for this identity provider. Used in the 'user_external_ids'
+    # table, as well as the query/path parameter used in the login protocol.
+    idp_id = attr.ib(type=str)
+
+    # user-facing name for this identity provider.
+    idp_name = attr.ib(type=str)
+
+    # Optional MXC URI for icon for this IdP.
+    idp_icon = attr.ib(type=Optional[str])
+
+    # whether the OIDC discovery mechanism is used to discover endpoints
+    discover = attr.ib(type=bool)
+
+    # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
+    # discover the provider's endpoints.
+    issuer = attr.ib(type=str)
+
+    # oauth2 client id to use
+    client_id = attr.ib(type=str)
+
+    # oauth2 client secret to use
+    client_secret = attr.ib(type=str)
+
+    # auth method to use when exchanging the token.
+    # Valid values are 'client_secret_basic', 'client_secret_post' and
+    # 'none'.
+    client_auth_method = attr.ib(type=str)
+
+    # list of scopes to request
+    scopes = attr.ib(type=Collection[str])
+
+    # the oauth2 authorization endpoint. Required if discovery is disabled.
+    authorization_endpoint = attr.ib(type=Optional[str])
+
+    # the oauth2 token endpoint. Required if discovery is disabled.
+    token_endpoint = attr.ib(type=Optional[str])
+
+    # the OIDC userinfo endpoint. Required if discovery is disabled and the
+    # "openid" scope is not requested.
+    userinfo_endpoint = attr.ib(type=Optional[str])
+
+    # URI where to fetch the JWKS. Required if discovery is disabled and the
+    # "openid" scope is used.
+    jwks_uri = attr.ib(type=Optional[str])
+
+    # Whether to skip metadata verification
+    skip_verification = attr.ib(type=bool)
+
+    # Whether to fetch the user profile from the userinfo endpoint. Valid
+    # values are: "auto" or "userinfo_endpoint".
+    user_profile_method = attr.ib(type=str)
+
+    # whether to allow a user logging in via OIDC to match a pre-existing account
+    # instead of failing
+    allow_existing_users = attr.ib(type=bool)
+
+    # the class of the user mapping provider
+    user_mapping_provider_class = attr.ib(type=Type)
+
+    # the config of the user mapping provider
+    user_mapping_provider_config = attr.ib()
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index cc5f75123c..4bfc69cb7a 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -14,14 +14,13 @@
 # limitations under the License.
 
 import os
-from distutils.util import strtobool
 
 import pkg_resources
 
 from synapse.api.constants import RoomCreationPreset
 from synapse.config._base import Config, ConfigError
 from synapse.types import RoomAlias, UserID
-from synapse.util.stringutils import random_string_with_symbols
+from synapse.util.stringutils import random_string_with_symbols, strtobool
 
 
 class AccountValidityConfig(Config):
@@ -50,10 +49,6 @@ class AccountValidityConfig(Config):
 
             self.startup_job_max_delta = self.period * 10.0 / 100.0
 
-        if self.renew_by_email_enabled:
-            if "public_baseurl" not in synapse_config:
-                raise ConfigError("Can't send renewal emails without 'public_baseurl'")
-
         template_dir = config.get("template_dir")
 
         if not template_dir:
@@ -86,12 +81,12 @@ class RegistrationConfig(Config):
     section = "registration"
 
     def read_config(self, config, **kwargs):
-        self.enable_registration = bool(
-            strtobool(str(config.get("enable_registration", False)))
+        self.enable_registration = strtobool(
+            str(config.get("enable_registration", False))
         )
         if "disable_registration" in config:
-            self.enable_registration = not bool(
-                strtobool(str(config["disable_registration"]))
+            self.enable_registration = not strtobool(
+                str(config["disable_registration"])
             )
 
         self.account_validity = AccountValidityConfig(
@@ -110,13 +105,6 @@ class RegistrationConfig(Config):
         account_threepid_delegates = config.get("account_threepid_delegates") or {}
         self.account_threepid_delegate_email = account_threepid_delegates.get("email")
         self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
-        if self.account_threepid_delegate_msisdn and not self.public_baseurl:
-            raise ConfigError(
-                "The configuration option `public_baseurl` is required if "
-                "`account_threepid_delegate.msisdn` is set, such that "
-                "clients know where to submit validation tokens to. Please "
-                "configure `public_baseurl`."
-            )
 
         self.default_identity_server = config.get("default_identity_server")
         self.allow_guest_access = config.get("allow_guest_access", False)
@@ -241,8 +229,9 @@ class RegistrationConfig(Config):
           # send an email to the account's email address with a renewal link. By
           # default, no such emails are sent.
           #
-          # If you enable this setting, you will also need to fill out the 'email' and
-          # 'public_baseurl' configuration sections.
+          # If you enable this setting, you will also need to fill out the 'email'
+          # configuration section. You should also check that 'public_baseurl' is set
+          # correctly.
           #
           #renew_at: 1w
 
@@ -333,8 +322,7 @@ class RegistrationConfig(Config):
         # The identity server which we suggest that clients should use when users log
         # in on this server.
         #
-        # (By default, no suggestion is made, so it is left up to the client.
-        # This setting is ignored unless public_baseurl is also set.)
+        # (By default, no suggestion is made, so it is left up to the client.)
         #
         #default_identity_server: https://matrix.org
 
@@ -359,8 +347,6 @@ class RegistrationConfig(Config):
         # by the Matrix Identity Service API specification:
         # https://matrix.org/docs/spec/identity_service/latest
         #
-        # If a delegate is specified, the config option public_baseurl must also be filled out.
-        #
         account_threepid_delegates:
             #email: https://example.com     # Delegate email sending to example.com
             #msisdn: http://localhost:8090  # Delegate SMS sending to this local process
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 7b97d4f114..f33dfa0d6a 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -189,8 +189,6 @@ class SAML2Config(Config):
         import saml2
 
         public_baseurl = self.public_baseurl
-        if public_baseurl is None:
-            raise ConfigError("saml2_config requires a public_baseurl to be set")
 
         if self.saml2_grandfathered_mxid_source_attribute:
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7242a4aa8e..47a0370173 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -26,7 +26,7 @@ import yaml
 from netaddr import IPSet
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.http.endpoint import parse_and_validate_server_name
+from synapse.util.stringutils import parse_and_validate_server_name
 
 from ._base import Config, ConfigError
 
@@ -161,7 +161,11 @@ class ServerConfig(Config):
         self.print_pidfile = config.get("print_pidfile")
         self.user_agent_suffix = config.get("user_agent_suffix")
         self.use_frozen_dicts = config.get("use_frozen_dicts", False)
-        self.public_baseurl = config.get("public_baseurl")
+        self.public_baseurl = config.get("public_baseurl") or "https://%s/" % (
+            self.server_name,
+        )
+        if self.public_baseurl[-1] != "/":
+            self.public_baseurl += "/"
 
         # Whether to enable user presence.
         self.use_presence = config.get("use_presence", True)
@@ -317,9 +321,6 @@ class ServerConfig(Config):
         # Always blacklist 0.0.0.0, ::
         self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
 
-        if self.public_baseurl is not None:
-            if self.public_baseurl[-1] != "/":
-                self.public_baseurl += "/"
         self.start_pushers = config.get("start_pushers", True)
 
         # (undocumented) option for torturing the worker-mode replication a bit,
@@ -740,11 +741,16 @@ class ServerConfig(Config):
         #
         #web_client_location: https://riot.example.com/
 
-        # The public-facing base URL that clients use to access this HS
-        # (not including _matrix/...). This is the same URL a user would
-        # enter into the 'custom HS URL' field on their client. If you
-        # use synapse with a reverse proxy, this should be the URL to reach
-        # synapse via the proxy.
+        # The public-facing base URL that clients use to access this Homeserver (not
+        # including _matrix/...). This is the same URL a user might enter into the
+        # 'Custom Homeserver URL' field on their client. If you use Synapse with a
+        # reverse proxy, this should be the URL to reach Synapse via the proxy.
+        # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
+        # 'listeners' below).
+        #
+        # If this is left unset, it defaults to 'https://<server_name>/'. (Note that
+        # that will not work unless you configure Synapse or a reverse-proxy to listen
+        # on port 443.)
         #
         #public_baseurl: https://example.com/
 
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 1aeb1c5c92..59be825532 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -37,6 +37,7 @@ class SSOConfig(Config):
             self.sso_error_template,
             sso_account_deactivated_template,
             sso_auth_success_template,
+            self.sso_auth_bad_user_template,
         ) = self.read_templates(
             [
                 "sso_login_idp_picker.html",
@@ -45,6 +46,7 @@ class SSOConfig(Config):
                 "sso_error.html",
                 "sso_account_deactivated.html",
                 "sso_auth_success.html",
+                "sso_auth_bad_user.html",
             ],
             template_dir,
         )
@@ -62,11 +64,8 @@ class SSOConfig(Config):
         # gracefully to the client). This would make it pointless to ask the user for
         # confirmation, since the URL the confirmation page would be showing wouldn't be
         # the client's.
-        # public_baseurl is an optional setting, so we only add the fallback's URL to the
-        # list if it's provided (because we can't figure out what that URL is otherwise).
-        if self.public_baseurl:
-            login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
-            self.sso_client_whitelist.append(login_fallback_url)
+        login_fallback_url = self.public_baseurl + "_matrix/static/client/login"
+        self.sso_client_whitelist.append(login_fallback_url)
 
     def generate_config_section(self, **kwargs):
         return """\
@@ -84,9 +83,9 @@ class SSOConfig(Config):
             # phishing attacks from evil.site. To avoid this, include a slash after the
             # hostname: "https://my.client/".
             #
-            # If public_baseurl is set, then the login fallback page (used by clients
-            # that don't natively support the required login flows) is whitelisted in
-            # addition to any URLs in this list.
+            # The login fallback page (used by clients that don't natively support the
+            # required login flows) is automatically whitelisted in addition to any URLs
+            # in this list.
             #
             # By default, this list is empty.
             #
@@ -160,6 +159,14 @@ class SSOConfig(Config):
             #
             #   This template has no additional variables.
             #
+            # * HTML page shown after a user-interactive authentication session which
+            #   does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+            #
+            #   When rendering, this template is given the following variables:
+            #     * server_name: the homeserver's name.
+            #     * user_id_to_verify: the MXID of the user that we are trying to
+            #       validate.
+            #
             # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
             #   attempts to login: 'sso_account_deactivated.html'.
             #
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 364583f48b..f10e33f7b8 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -56,6 +56,12 @@ class WriterLocations:
     to_device = attr.ib(
         default=["master"], type=List[str], converter=_instance_to_list_converter,
     )
+    account_data = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
+    receipts = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
 
 
 class WorkerConfig(Config):
@@ -127,7 +133,7 @@ class WorkerConfig(Config):
 
         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing", "to_device"):
+        for stream in ("events", "typing", "to_device", "account_data", "receipts"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -141,6 +147,16 @@ class WorkerConfig(Config):
                 "Must only specify one instance to handle `to_device` messages."
             )
 
+        if len(self.writers.account_data) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `account_data` messages."
+            )
+
+        if len(self.writers.receipts) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `receipts` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
         # Whether this worker should run background tasks or not.
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8028663fa8..3ec4120f85 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -17,7 +17,6 @@
 
 import abc
 import os
-from distutils.util import strtobool
 from typing import Dict, Optional, Tuple, Type
 
 from unpaddedbase64 import encode_base64
@@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
+from synapse.util.stringutils import strtobool
 
 # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
 # bugs where we accidentally share e.g. signature dicts. However, converting a
@@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
 # NOTE: This is overridden by the configuration by the Synapse worker apps, but
 # for the sake of tests, it is set here while it cannot be configured on the
 # homeserver object itself.
+
 USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
 
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 302b2f69bc..d330ae5dbc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,6 +18,7 @@ import copy
 import itertools
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -26,7 +27,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Sequence,
     Tuple,
     TypeVar,
     Union,
@@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
 
 
 class FederationClient(FederationBase):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.pdu_destination_tried = {}
+        self.pdu_destination_tried = {}  # type: Dict[str, Dict[str, int]]
         self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
         self.state = hs.get_state_handler()
         self.transport_layer = hs.get_federation_transport_client()
@@ -116,33 +119,32 @@ class FederationClient(FederationBase):
                 self.pdu_destination_tried[event_id] = destination_dict
 
     @log_function
-    def make_query(
+    async def make_query(
         self,
-        destination,
-        query_type,
-        args,
-        retry_on_dns_fail=False,
-        ignore_backoff=False,
-    ):
+        destination: str,
+        query_type: str,
+        args: dict,
+        retry_on_dns_fail: bool = False,
+        ignore_backoff: bool = False,
+    ) -> JsonDict:
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            query_type (str): Category of the query type; should match the
+            destination: Domain name of the remote homeserver
+            query_type: Category of the query type; should match the
                 handler name used in register_query_handler().
-            args (dict): Mapping of strings to strings containing the details
+            args: Mapping of strings to strings containing the details
                 of the query request.
-            ignore_backoff (bool): true to ignore the historical backoff data
+            ignore_backoff: true to ignore the historical backoff data
                 and try the request anyway.
 
         Returns:
-            a Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels(query_type).inc()
 
-        return self.transport_layer.make_query(
+        return await self.transport_layer.make_query(
             destination,
             query_type,
             args,
@@ -151,42 +153,52 @@ class FederationClient(FederationBase):
         )
 
     @log_function
-    def query_client_keys(self, destination, content, timeout):
+    async def query_client_keys(
+        self, destination: str, content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Query device keys for a device hosted on a remote server.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            content (dict): The query content.
+            destination: Domain name of the remote homeserver
+            content: The query content.
 
         Returns:
-            an Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels("client_device_keys").inc()
-        return self.transport_layer.query_client_keys(destination, content, timeout)
+        return await self.transport_layer.query_client_keys(
+            destination, content, timeout
+        )
 
     @log_function
-    def query_user_devices(self, destination, user_id, timeout=30000):
+    async def query_user_devices(
+        self, destination: str, user_id: str, timeout: int = 30000
+    ) -> JsonDict:
         """Query the device keys for a list of user ids hosted on a remote
         server.
         """
         sent_queries_counter.labels("user_devices").inc()
-        return self.transport_layer.query_user_devices(destination, user_id, timeout)
+        return await self.transport_layer.query_user_devices(
+            destination, user_id, timeout
+        )
 
     @log_function
-    def claim_client_keys(self, destination, content, timeout):
+    async def claim_client_keys(
+        self, destination: str, content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            content (dict): The query content.
+            destination: Domain name of the remote homeserver
+            content: The query content.
 
         Returns:
-            an Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
-        return self.transport_layer.claim_client_keys(destination, content, timeout)
+        return await self.transport_layer.claim_client_keys(
+            destination, content, timeout
+        )
 
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@@ -195,10 +207,10 @@ class FederationClient(FederationBase):
         given destination server.
 
         Args:
-            dest (str): The remote homeserver to ask.
-            room_id (str): The room_id to backfill.
-            limit (int): The maximum number of events to return.
-            extremities (list): our current backwards extremities, to backfill from
+            dest: The remote homeserver to ask.
+            room_id: The room_id to backfill.
+            limit: The maximum number of events to return.
+            extremities: our current backwards extremities, to backfill from
         """
         logger.debug("backfill extrem=%s", extremities)
 
@@ -370,7 +382,7 @@ class FederationClient(FederationBase):
                 for events that have failed their checks
 
         Returns:
-            Deferred : A list of PDUs that have valid signatures and hashes.
+            A list of PDUs that have valid signatures and hashes.
         """
         deferreds = self._check_sigs_and_hashes(room_version, pdus)
 
@@ -418,7 +430,9 @@ class FederationClient(FederationBase):
         else:
             return [p for p in valid_pdus if p]
 
-    async def get_event_auth(self, destination, room_id, event_id):
+    async def get_event_auth(
+        self, destination: str, room_id: str, event_id: str
+    ) -> List[EventBase]:
         res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
 
         room_version = await self.store.get_room_version(room_id)
@@ -700,18 +714,16 @@ class FederationClient(FederationBase):
 
         return await self._try_destination_list("send_join", destinations, send_request)
 
-    async def _do_send_join(self, destination: str, pdu: EventBase):
+    async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_join_v2(
+            return await self.transport_layer.send_join_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
-
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -769,7 +781,7 @@ class FederationClient(FederationBase):
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_invite_v2(
+            return await self.transport_layer.send_invite_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
@@ -779,7 +791,6 @@ class FederationClient(FederationBase):
                     "invite_room_state": pdu.unsigned.get("invite_room_state", []),
                 },
             )
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -842,18 +853,16 @@ class FederationClient(FederationBase):
             "send_leave", destinations, send_request
         )
 
-    async def _do_send_leave(self, destination, pdu):
+    async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_leave_v2(
+            return await self.transport_layer.send_leave_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
-
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -879,7 +888,7 @@ class FederationClient(FederationBase):
         # content.
         return resp[1]
 
-    def get_public_rooms(
+    async def get_public_rooms(
         self,
         remote_server: str,
         limit: Optional[int] = None,
@@ -887,7 +896,7 @@ class FederationClient(FederationBase):
         search_filter: Optional[Dict] = None,
         include_all_networks: bool = False,
         third_party_instance_id: Optional[str] = None,
-    ):
+    ) -> JsonDict:
         """Get the list of public rooms from a remote homeserver
 
         Args:
@@ -901,8 +910,7 @@ class FederationClient(FederationBase):
                 party instance
 
         Returns:
-            Awaitable[Dict[str, Any]]: The response from the remote server, or None if
-            `remote_server` is the same as the local server_name
+            The response from the remote server.
 
         Raises:
             HttpResponseException: There was an exception returned from the remote server
@@ -910,7 +918,7 @@ class FederationClient(FederationBase):
                 requests over federation
 
         """
-        return self.transport_layer.get_public_rooms(
+        return await self.transport_layer.get_public_rooms(
             remote_server,
             limit,
             since_token,
@@ -923,7 +931,7 @@ class FederationClient(FederationBase):
         self,
         destination: str,
         room_id: str,
-        earliest_events_ids: Sequence[str],
+        earliest_events_ids: Iterable[str],
         latest_events: Iterable[EventBase],
         limit: int,
         min_depth: int,
@@ -974,7 +982,9 @@ class FederationClient(FederationBase):
 
         return signed_events
 
-    async def forward_third_party_invite(self, destinations, room_id, event_dict):
+    async def forward_third_party_invite(
+        self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
+    ) -> None:
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -983,7 +993,7 @@ class FederationClient(FederationBase):
                 await self.transport_layer.exchange_third_party_invite(
                     destination=destination, room_id=room_id, event_dict=event_dict
                 )
-                return None
+                return
             except CodeMessageException:
                 raise
             except Exception as e:
@@ -995,7 +1005,7 @@ class FederationClient(FederationBase):
 
     async def get_room_complexity(
         self, destination: str, room_id: str
-    ) -> Optional[dict]:
+    ) -> Optional[JsonDict]:
         """
         Fetch the complexity of a remote room from another server.
 
@@ -1008,10 +1018,9 @@ class FederationClient(FederationBase):
             could not fetch the complexity.
         """
         try:
-            complexity = await self.transport_layer.get_room_complexity(
+            return await self.transport_layer.get_room_complexity(
                 destination=destination, room_id=room_id
             )
-            return complexity
         except CodeMessageException as e:
             # We didn't manage to get it -- probably a 404. We are okay if other
             # servers don't give it to us.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e5339aca23..171d25c945 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -49,7 +49,6 @@ from synapse.events import EventBase
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
-from synapse.http.endpoint import parse_server_name
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
     make_deferred_yieldable,
@@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_server_name
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index cfd094e58f..95c64510a9 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -28,7 +28,6 @@ from synapse.api.urls import (
     FEDERATION_V1_PREFIX,
     FEDERATION_V2_PREFIX,
 )
-from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
     parse_boolean_from_args,
@@ -45,6 +44,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.server import HomeServer
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
+from synapse.util.stringutils import parse_and_validate_server_name
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 341135822e..b1a5df9638 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,14 +13,157 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 from typing import TYPE_CHECKING, List, Tuple
 
+from synapse.replication.http.account_data import (
+    ReplicationAddTagRestServlet,
+    ReplicationRemoveTagRestServlet,
+    ReplicationRoomAccountDataRestServlet,
+    ReplicationUserAccountDataRestServlet,
+)
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
 
+class AccountDataHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._store = hs.get_datastore()
+        self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
+
+        self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+        self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+        self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+        self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+        self._account_data_writers = hs.config.worker.writers.account_data
+
+    async def add_account_data_to_room(
+        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The maximum stream ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_to_room(
+                user_id, room_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+
+            return max_stream_id
+        else:
+            response = await self._room_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_account_data_for_user(
+        self, user_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The maximum stream ID.
+        """
+
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_for_user(
+                user_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._user_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_tag_to_room(
+        self, user_id: str, room_id: str, tag: str, content: JsonDict
+    ) -> int:
+        """Add a tag to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            tag: The tag name to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_tag_to_room(
+                user_id, room_id, tag, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._add_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+        """Remove a tag from a room for a user.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.remove_tag_from_room(
+                user_id, room_id, tag
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._remove_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+            )
+            return response["max_stream_id"]
+
+
 class AccountDataEventSource:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f4434673dc..0e98db22b3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -49,8 +49,13 @@ from synapse.api.errors import (
     UserDeactivatedError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.ui_auth import (
+    INTERACTIVE_AUTH_CHECKERS,
+    UIAuthSessionDataConstants,
+)
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http import get_request_user_agent
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
@@ -62,8 +67,6 @@ from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
-from ._base import BaseHandler
-
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
@@ -260,10 +263,6 @@ class AuthHandler(BaseHandler):
         # authenticating for an operation to occur on their account.
         self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
 
-        # The following template is shown after a successful user interactive
-        # authentication session. It tells the user they can close the window.
-        self._sso_auth_success_template = hs.config.sso_auth_success_template
-
         # The following template is shown during the SSO authentication process if
         # the account is deactivated.
         self._sso_account_deactivated_template = (
@@ -284,7 +283,6 @@ class AuthHandler(BaseHandler):
         requester: Requester,
         request: SynapseRequest,
         request_body: Dict[str, Any],
-        clientip: str,
         description: str,
     ) -> Tuple[dict, Optional[str]]:
         """
@@ -301,8 +299,6 @@ class AuthHandler(BaseHandler):
 
             request_body: The body of the request sent by the client
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
@@ -338,10 +334,10 @@ class AuthHandler(BaseHandler):
                 request_body.pop("auth", None)
                 return request_body, None
 
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         # Check if we should be ratelimited due to too many previous failed attempts
-        self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+        self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
 
         # build a list of supported flows
         supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -349,13 +345,16 @@ class AuthHandler(BaseHandler):
         )
         flows = [[login_type] for login_type in supported_ui_auth_types]
 
+        def get_new_session_data() -> JsonDict:
+            return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
+
         try:
             result, params, session_id = await self.check_ui_auth(
-                flows, request, request_body, clientip, description
+                flows, request, request_body, description, get_new_session_data,
             )
         except LoginError:
             # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
-            self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+            self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
             raise
 
         # find the completed login type
@@ -363,14 +362,14 @@ class AuthHandler(BaseHandler):
             if login_type not in result:
                 continue
 
-            user_id = result[login_type]
+            validated_user_id = result[login_type]
             break
         else:
             # this can't happen
             raise Exception("check_auth returned True but no successful login type")
 
         # check that the UI auth matched the access token
-        if user_id != requester.user.to_string():
+        if validated_user_id != requester_user_id:
             raise AuthError(403, "Invalid auth")
 
         # Note that the access token has been validated.
@@ -402,13 +401,9 @@ class AuthHandler(BaseHandler):
 
         # if sso is enabled, allow the user to log in via SSO iff they have a mapping
         # from sso to mxid.
-        if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
-            if await self.store.get_external_ids_by_user(user.to_string()):
-                ui_auth_types.add(LoginType.SSO)
-
-        # Our CAS impl does not (yet) correctly register users in user_external_ids,
-        # so always offer that if it's available.
-        if self.hs.config.cas.cas_enabled:
+        if await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user.to_string()
+        ):
             ui_auth_types.add(LoginType.SSO)
 
         return ui_auth_types
@@ -426,8 +421,8 @@ class AuthHandler(BaseHandler):
         flows: List[List[str]],
         request: SynapseRequest,
         clientdict: Dict[str, Any],
-        clientip: str,
         description: str,
+        get_new_session_data: Optional[Callable[[], JsonDict]] = None,
     ) -> Tuple[dict, dict, str]:
         """
         Takes a dictionary sent by the client in the login / registration
@@ -448,11 +443,16 @@ class AuthHandler(BaseHandler):
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
+            get_new_session_data:
+                an optional callback which will be called when starting a new session.
+                it should return data to be stored as part of the session.
+
+                The keys of the returned data should be entries in
+                UIAuthSessionDataConstants.
+
         Returns:
             A tuple of (creds, params, session_id).
 
@@ -480,10 +480,15 @@ class AuthHandler(BaseHandler):
 
         # If there's no session ID, create a new session.
         if not sid:
+            new_session_data = get_new_session_data() if get_new_session_data else {}
+
             session = await self.store.create_ui_auth_session(
                 clientdict, uri, method, description
             )
 
+            for k, v in new_session_data.items():
+                await self.set_session_data(session.session_id, k, v)
+
         else:
             try:
                 session = await self.store.get_ui_auth_session(sid)
@@ -539,7 +544,8 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
-        user_agent = request.get_user_agent("")
+        user_agent = get_request_user_agent(request)
+        clientip = request.getClientIP()
 
         await self.store.add_user_agent_ip_to_ui_auth_session(
             session.session_id, user_agent, clientip
@@ -644,7 +650,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key to store the data under. An entry from
+                UIAuthSessionDataConstants.
             value: The data to store
         """
         try:
@@ -660,7 +667,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key the data was stored under. An entry from
+                UIAuthSessionDataConstants.
             default: Value to return if the key has not been set
         """
         try:
@@ -1334,12 +1342,12 @@ class AuthHandler(BaseHandler):
         else:
             return False
 
-    async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+    async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
         """
         Get the HTML for the SSO redirect confirmation page.
 
         Args:
-            redirect_url: The URL to redirect to the SSO provider.
+            request: The incoming HTTP request
             session_id: The user interactive authentication session ID.
 
         Returns:
@@ -1349,30 +1357,38 @@ class AuthHandler(BaseHandler):
             session = await self.store.get_ui_auth_session(session_id)
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
-        return self._sso_auth_confirm_template.render(
-            description=session.description, redirect_url=redirect_url,
+
+        user_id_to_verify = await self.get_session_data(
+            session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
+        idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user_id_to_verify
         )
 
-    async def complete_sso_ui_auth(
-        self, registered_user_id: str, session_id: str, request: Request,
-    ):
-        """Having figured out a mxid for this user, complete the HTTP request
+        if not idps:
+            # we checked that the user had some remote identities before offering an SSO
+            # flow, so either it's been deleted or the client has requested SSO despite
+            # it not being offered.
+            raise SynapseError(400, "User has no SSO identities")
 
-        Args:
-            registered_user_id: The registered user ID to complete SSO login for.
-            session_id: The ID of the user-interactive auth session.
-            request: The request to complete.
-        """
-        # Mark the stage of the authentication as successful.
-        # Save the user who authenticated with SSO, this will be used to ensure
-        # that the account be modified is also the person who logged in.
-        await self.store.mark_ui_auth_stage_complete(
-            session_id, LoginType.SSO, registered_user_id
+        # for now, just pick one
+        idp_id, sso_auth_provider = next(iter(idps.items()))
+        if len(idps) > 0:
+            logger.warning(
+                "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+                "picking %r",
+                user_id_to_verify,
+                idp_id,
+            )
+
+        redirect_url = await sso_auth_provider.handle_redirect_request(
+            request, None, session_id
         )
 
-        # Render the HTML and return.
-        html = self._sso_auth_success_template
-        respond_with_html(request, 200, html)
+        return self._sso_auth_confirm_template.render(
+            description=session.description, redirect_url=redirect_url,
+        )
 
     async def complete_sso_login(
         self,
@@ -1488,8 +1504,8 @@ class AuthHandler(BaseHandler):
     @staticmethod
     def add_query_param_to_url(url: str, param_name: str, param: Any):
         url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({param_name: param})
+        query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+        query.append((param_name, param))
         url_parts[4] = urllib.parse.urlencode(query)
         return urllib.parse.urlunparse(url_parts)
 
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f3430c6713..0f342c607b 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -80,6 +80,10 @@ class CasHandler:
         # user-facing name of this auth provider
         self.idp_name = "CAS"
 
+        # we do not currently support icons for CAS auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
+
         self._sso_handler = hs.get_sso_handler()
 
         self._sso_handler.register_identity_provider(self)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e808142365..c4a3b26a84 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
 
 from ._base import BaseHandler
 
@@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
         self._device_handler = hs.get_device_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._identity_handler = hs.get_identity_handler()
+        self._profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
         self._server_name = hs.hostname
 
@@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
         self._account_validity_enabled = hs.config.account_validity.enabled
 
     async def deactivate_account(
-        self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+        self,
+        user_id: str,
+        erase_data: bool,
+        requester: Requester,
+        id_server: Optional[str] = None,
+        by_admin: bool = False,
     ) -> bool:
         """Deactivate a user's account
 
         Args:
             user_id: ID of user to be deactivated
             erase_data: whether to GDPR-erase the user's data
+            requester: The user attempting to make this change.
             id_server: Use the given identity server when unbinding
                 any threepids. If None then will attempt to unbind using the
                 identity server specified when binding (if known).
+            by_admin: Whether this change was made by an administrator.
 
         Returns:
             True if identity server supports removing threepids, otherwise False.
@@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
 
         # Mark the user as erased, if they asked for that
         if erase_data:
+            user = UserID.from_string(user_id)
+            # Remove avatar URL from this user
+            await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+            # Remove displayname from this user
+            await self._profile_handler.set_displayname(user, requester, "", by_admin)
+
             logger.info("Marking %s as erased", user_id)
             await self.store.mark_user_erased(user_id)
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index fc974a82e8..0c7737e09d 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -163,7 +163,7 @@ class DeviceMessageHandler:
             await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
-            run_in_background(self._user_device_resync, sender_user_id)
+            run_in_background(self._user_device_resync, user_id=sender_user_id)
 
     async def send_device_message(
         self,
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index c05036ad1f..f61844d688 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler):
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
-        assert self.hs.config.public_baseurl
-
         # we need to tell the client to send the token back to us, since it doesn't
         # otherwise know where to send it, so add submit_url response parameter
         # (see also MSC2078)
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 6835c6c462..1607e12935 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import inspect
 import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
 from urllib.parse import urlencode
 
 import attr
@@ -35,7 +35,7 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
+from synapse.config.oidc_config import OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -71,6 +71,144 @@ JWK = Dict[str, str]
 JWKS = TypedDict("JWKS", {"keys": List[JWK]})
 
 
+class OidcHandler:
+    """Handles requests related to the OpenID Connect login flow.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._sso_handler = hs.get_sso_handler()
+
+        provider_confs = hs.config.oidc.oidc_providers
+        # we should not have been instantiated if there is no configured provider.
+        assert provider_confs
+
+        self._token_generator = OidcSessionTokenGenerator(hs)
+        self._providers = {
+            p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+        }  # type: Dict[str, OidcProvider]
+
+    async def load_metadata(self) -> None:
+        """Validate the config and load the metadata from the remote endpoint.
+
+        Called at startup to ensure we have everything we need.
+        """
+        for idp_id, p in self._providers.items():
+            try:
+                await p.load_metadata()
+                await p.load_jwks()
+            except Exception as e:
+                raise Exception(
+                    "Error while initialising OIDC provider %r" % (idp_id,)
+                ) from e
+
+    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/oidc/callback
+
+        Since we might want to display OIDC-related errors in a user-friendly
+        way, we don't raise SynapseError from here. Instead, we call
+        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+
+        Most of the OpenID Connect logic happens here:
+
+          - first, we check if there was any error returned by the provider and
+            display it
+          - then we fetch the session cookie, decode and verify it
+          - the ``state`` query parameter should match with the one stored in the
+            session cookie
+
+        Once we know the session is legit, we then delegate to the OIDC Provider
+        implementation, which will exchange the code with the provider and complete the
+        login/authentication.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+
+        # The provider might redirect with an error.
+        # In that case, just display it as-is.
+        if b"error" in request.args:
+            # error response from the auth server. see:
+            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+            error = request.args[b"error"][0].decode()
+            description = request.args.get(b"error_description", [b""])[0].decode()
+
+            # Most of the errors returned by the provider could be due by
+            # either the provider misbehaving or Synapse being misconfigured.
+            # The only exception of that is "access_denied", where the user
+            # probably cancelled the login flow. In other cases, log those errors.
+            if error != "access_denied":
+                logger.error("Error from the OIDC provider: %s %s", error, description)
+
+            self._sso_handler.render_error(request, error, description)
+            return
+
+        # otherwise, it is presumably a successful response. see:
+        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+        # Fetch the session cookie
+        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
+        if session is None:
+            logger.info("No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
+            return
+
+        # Remove the cookie. There is a good chance that if the callback failed
+        # once, it will fail next time and the code will already be exchanged.
+        # Removing it early avoids spamming the provider with token requests.
+        request.addCookie(
+            SESSION_COOKIE_NAME,
+            b"",
+            path="/_synapse/oidc",
+            expires="Thu, Jan 01 1970 00:00:00 UTC",
+            httpOnly=True,
+            sameSite="lax",
+        )
+
+        # Check for the state query parameter
+        if b"state" not in request.args:
+            logger.info("State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
+            return
+
+        state = request.args[b"state"][0].decode()
+
+        # Deserialize the session token and verify it.
+        try:
+            session_data = self._token_generator.verify_oidc_session_token(
+                session, state
+            )
+        except (MacaroonDeserializationException, ValueError) as e:
+            logger.exception("Invalid session")
+            self._sso_handler.render_error(request, "invalid_session", str(e))
+            return
+        except MacaroonInvalidSignatureException as e:
+            logger.exception("Could not verify session")
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
+            return
+
+        oidc_provider = self._providers.get(session_data.idp_id)
+        if not oidc_provider:
+            logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+            self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+            return
+
+        if b"code" not in request.args:
+            logger.info("Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
+            return
+
+        code = request.args[b"code"][0].decode()
+
+        await oidc_provider.handle_oidc_callback(request, session_data, code)
+
+
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint
     """
@@ -85,44 +223,56 @@ class OidcError(Exception):
         return self.error
 
 
-class OidcHandler(BaseHandler):
-    """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+    """Wraps the config for a single OIDC IdentityProvider
+
+    Provides methods for handling redirect requests and callbacks via that particular
+    IdP.
     """
 
-    def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+    def __init__(
+        self,
+        hs: "HomeServer",
+        token_generator: "OidcSessionTokenGenerator",
+        provider: OidcProviderConfig,
+    ):
+        self._store = hs.get_datastore()
+
+        self._token_generator = token_generator
+
         self._callback_url = hs.config.oidc_callback_url  # type: str
-        self._scopes = hs.config.oidc_scopes  # type: List[str]
-        self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
+
+        self._scopes = provider.scopes
+        self._user_profile_method = provider.user_profile_method
         self._client_auth = ClientAuth(
-            hs.config.oidc_client_id,
-            hs.config.oidc_client_secret,
-            hs.config.oidc_client_auth_method,
+            provider.client_id, provider.client_secret, provider.client_auth_method,
         )  # type: ClientAuth
-        self._client_auth_method = hs.config.oidc_client_auth_method  # type: str
+        self._client_auth_method = provider.client_auth_method
         self._provider_metadata = OpenIDProviderMetadata(
-            issuer=hs.config.oidc_issuer,
-            authorization_endpoint=hs.config.oidc_authorization_endpoint,
-            token_endpoint=hs.config.oidc_token_endpoint,
-            userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
-            jwks_uri=hs.config.oidc_jwks_uri,
+            issuer=provider.issuer,
+            authorization_endpoint=provider.authorization_endpoint,
+            token_endpoint=provider.token_endpoint,
+            userinfo_endpoint=provider.userinfo_endpoint,
+            jwks_uri=provider.jwks_uri,
         )  # type: OpenIDProviderMetadata
-        self._provider_needs_discovery = hs.config.oidc_discover  # type: bool
-        self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
-            hs.config.oidc_user_mapping_provider_config
-        )  # type: OidcMappingProvider
-        self._skip_verification = hs.config.oidc_skip_verification  # type: bool
-        self._allow_existing_users = hs.config.oidc_allow_existing_users  # type: bool
+        self._provider_needs_discovery = provider.discover
+        self._user_mapping_provider = provider.user_mapping_provider_class(
+            provider.user_mapping_provider_config
+        )
+        self._skip_verification = provider.skip_verification
+        self._allow_existing_users = provider.allow_existing_users
 
         self._http_client = hs.get_proxied_http_client()
         self._server_name = hs.config.server_name  # type: str
-        self._macaroon_secret_key = hs.config.macaroon_secret_key
 
         # identifier for the external_ids table
-        self.idp_id = "oidc"
+        self.idp_id = provider.idp_id
 
         # user-facing name of this auth provider
-        self.idp_name = "OIDC"
+        self.idp_name = provider.idp_name
+
+        # MXC URI for icon for this auth provider
+        self.idp_icon = provider.idp_icon
 
         self._sso_handler = hs.get_sso_handler()
 
@@ -519,11 +669,14 @@ class OidcHandler(BaseHandler):
         if not client_redirect_url:
             client_redirect_url = b""
 
-        cookie = self._generate_oidc_session_token(
+        cookie = self._token_generator.generate_oidc_session_token(
             state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url.decode(),
-            ui_auth_session_id=ui_auth_session_id,
+            session_data=OidcSessionData(
+                idp_id=self.idp_id,
+                nonce=nonce,
+                client_redirect_url=client_redirect_url.decode(),
+                ui_auth_session_id=ui_auth_session_id,
+            ),
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
@@ -546,22 +699,16 @@ class OidcHandler(BaseHandler):
             nonce=nonce,
         )
 
-    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+    async def handle_oidc_callback(
+        self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+    ) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
 
-        Since we might want to display OIDC-related errors in a user-friendly
-        way, we don't raise SynapseError from here. Instead, we call
-        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+        By this time we have already validated the session on the synapse side, and
+        now need to do the provider-specific operations. This includes:
 
-        Most of the OpenID Connect logic happens here:
-
-          - first, we check if there was any error returned by the provider and
-            display it
-          - then we fetch the session cookie, decode and verify it
-          - the ``state`` query parameter should match with the one stored in the
-            session cookie
-          - once we known this session is legit, exchange the code with the
-            provider using the ``token_endpoint`` (see ``_exchange_code``)
+          - exchange the code with the provider using the ``token_endpoint`` (see
+            ``_exchange_code``)
           - once we have the token, use it to either extract the UserInfo from
             the ``id_token`` (``_parse_id_token``), or use the ``access_token``
             to fetch UserInfo from the ``userinfo_endpoint``
@@ -571,88 +718,12 @@ class OidcHandler(BaseHandler):
 
         Args:
             request: the incoming request from the browser.
+            session_data: the session data, extracted from our cookie
+            code: The authorization code we got from the callback.
         """
-
-        # The provider might redirect with an error.
-        # In that case, just display it as-is.
-        if b"error" in request.args:
-            # error response from the auth server. see:
-            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
-            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
-            error = request.args[b"error"][0].decode()
-            description = request.args.get(b"error_description", [b""])[0].decode()
-
-            # Most of the errors returned by the provider could be due by
-            # either the provider misbehaving or Synapse being misconfigured.
-            # The only exception of that is "access_denied", where the user
-            # probably cancelled the login flow. In other cases, log those errors.
-            if error != "access_denied":
-                logger.error("Error from the OIDC provider: %s %s", error, description)
-
-            self._sso_handler.render_error(request, error, description)
-            return
-
-        # otherwise, it is presumably a successful response. see:
-        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
-
-        # Fetch the session cookie
-        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
-        if session is None:
-            logger.info("No session cookie found")
-            self._sso_handler.render_error(
-                request, "missing_session", "No session cookie found"
-            )
-            return
-
-        # Remove the cookie. There is a good chance that if the callback failed
-        # once, it will fail next time and the code will already be exchanged.
-        # Removing it early avoids spamming the provider with token requests.
-        request.addCookie(
-            SESSION_COOKIE_NAME,
-            b"",
-            path="/_synapse/oidc",
-            expires="Thu, Jan 01 1970 00:00:00 UTC",
-            httpOnly=True,
-            sameSite="lax",
-        )
-
-        # Check for the state query parameter
-        if b"state" not in request.args:
-            logger.info("State parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "State parameter is missing"
-            )
-            return
-
-        state = request.args[b"state"][0].decode()
-
-        # Deserialize the session token and verify it.
-        try:
-            (
-                nonce,
-                client_redirect_url,
-                ui_auth_session_id,
-            ) = self._verify_oidc_session_token(session, state)
-        except MacaroonDeserializationException as e:
-            logger.exception("Invalid session")
-            self._sso_handler.render_error(request, "invalid_session", str(e))
-            return
-        except MacaroonInvalidSignatureException as e:
-            logger.exception("Could not verify session")
-            self._sso_handler.render_error(request, "mismatching_session", str(e))
-            return
-
         # Exchange the code with the provider
-        if b"code" not in request.args:
-            logger.info("Code parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "Code parameter is missing"
-            )
-            return
-
-        logger.debug("Exchanging code")
-        code = request.args[b"code"][0].decode()
         try:
+            logger.debug("Exchanging code")
             token = await self._exchange_code(code)
         except OidcError as e:
             logger.exception("Could not exchange code")
@@ -674,14 +745,14 @@ class OidcHandler(BaseHandler):
         else:
             logger.debug("Extracting userinfo from id_token")
             try:
-                userinfo = await self._parse_id_token(token, nonce=nonce)
+                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
             except Exception as e:
                 logger.exception("Invalid id_token")
                 self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
         # first check if we're doing a UIA
-        if ui_auth_session_id:
+        if session_data.ui_auth_session_id:
             try:
                 remote_user_id = self._remote_id_from_userinfo(userinfo)
             except Exception as e:
@@ -690,7 +761,7 @@ class OidcHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self.idp_id, remote_user_id, ui_auth_session_id, request
+                self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
             )
 
         # otherwise, it's a login
@@ -698,133 +769,12 @@ class OidcHandler(BaseHandler):
         # Call the mapper to register/login the user
         try:
             await self._complete_oidc_login(
-                userinfo, token, request, client_redirect_url
+                userinfo, token, request, session_data.client_redirect_url
             )
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
 
-    def _generate_oidc_session_token(
-        self,
-        state: str,
-        nonce: str,
-        client_redirect_url: str,
-        ui_auth_session_id: Optional[str],
-        duration_in_ms: int = (60 * 60 * 1000),
-    ) -> str:
-        """Generates a signed token storing data about an OIDC session.
-
-        When Synapse initiates an authorization flow, it creates a random state
-        and a random nonce. Those parameters are given to the provider and
-        should be verified when the client comes back from the provider.
-        It is also used to store the client_redirect_url, which is used to
-        complete the SSO login flow.
-
-        Args:
-            state: The ``state`` parameter passed to the OIDC provider.
-            nonce: The ``nonce`` parameter passed to the OIDC provider.
-            client_redirect_url: The URL the client gave when it initiated the
-                flow.
-            ui_auth_session_id: The session ID of the ongoing UI Auth (or
-                None if this is a login).
-            duration_in_ms: An optional duration for the token in milliseconds.
-                Defaults to an hour.
-
-        Returns:
-            A signed macaroon token with the session information.
-        """
-        macaroon = pymacaroons.Macaroon(
-            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
-        )
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("type = session")
-        macaroon.add_first_party_caveat("state = %s" % (state,))
-        macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
-        macaroon.add_first_party_caveat(
-            "client_redirect_url = %s" % (client_redirect_url,)
-        )
-        if ui_auth_session_id:
-            macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (ui_auth_session_id,)
-            )
-        now = self.clock.time_msec()
-        expiry = now + duration_in_ms
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-
-        return macaroon.serialize()
-
-    def _verify_oidc_session_token(
-        self, session: bytes, state: str
-    ) -> Tuple[str, str, Optional[str]]:
-        """Verifies and extract an OIDC session token.
-
-        This verifies that a given session token was issued by this homeserver
-        and extract the nonce and client_redirect_url caveats.
-
-        Args:
-            session: The session token to verify
-            state: The state the OIDC provider gave back
-
-        Returns:
-            The nonce, client_redirect_url, and ui_auth_session_id for this session
-        """
-        macaroon = pymacaroons.Macaroon.deserialize(session)
-
-        v = pymacaroons.Verifier()
-        v.satisfy_exact("gen = 1")
-        v.satisfy_exact("type = session")
-        v.satisfy_exact("state = %s" % (state,))
-        v.satisfy_general(lambda c: c.startswith("nonce = "))
-        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
-        # to always satisfy this.
-        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        v.satisfy_general(self._verify_expiry)
-
-        v.verify(macaroon, self._macaroon_secret_key)
-
-        # Extract the `nonce`, `client_redirect_url`, and maybe the
-        # `ui_auth_session_id` from the token.
-        nonce = self._get_value_from_macaroon(macaroon, "nonce")
-        client_redirect_url = self._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
-        try:
-            ui_auth_session_id = self._get_value_from_macaroon(
-                macaroon, "ui_auth_session_id"
-            )  # type: Optional[str]
-        except ValueError:
-            ui_auth_session_id = None
-
-        return nonce, client_redirect_url, ui_auth_session_id
-
-    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
-        """Extracts a caveat value from a macaroon token.
-
-        Args:
-            macaroon: the token
-            key: the key of the caveat to extract
-
-        Returns:
-            The extracted value
-
-        Raises:
-            Exception: if the caveat was not in the macaroon
-        """
-        prefix = key + " = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(prefix):
-                return caveat.caveat_id[len(prefix) :]
-        raise ValueError("No %s caveat in macaroon" % (key,))
-
-    def _verify_expiry(self, caveat: str) -> bool:
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self.clock.time_msec()
-        return now < expiry
-
     async def _complete_oidc_login(
         self,
         userinfo: UserInfo,
@@ -901,8 +851,8 @@ class OidcHandler(BaseHandler):
                 # and attempt to match it.
                 attributes = await oidc_response_to_user_attributes(failures=0)
 
-                user_id = UserID(attributes.localpart, self.server_name).to_string()
-                users = await self.store.get_users_by_id_case_insensitive(user_id)
+                user_id = UserID(attributes.localpart, self._server_name).to_string()
+                users = await self._store.get_users_by_id_case_insensitive(user_id)
                 if users:
                     # If an existing matrix ID is returned, then use it.
                     if len(users) == 1:
@@ -954,6 +904,157 @@ class OidcHandler(BaseHandler):
         return str(remote_user_id)
 
 
+class OidcSessionTokenGenerator:
+    """Methods for generating and checking OIDC Session cookies."""
+
+    def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
+        self._server_name = hs.hostname
+        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
+
+    def generate_oidc_session_token(
+        self,
+        state: str,
+        session_data: "OidcSessionData",
+        duration_in_ms: int = (60 * 60 * 1000),
+    ) -> str:
+        """Generates a signed token storing data about an OIDC session.
+
+        When Synapse initiates an authorization flow, it creates a random state
+        and a random nonce. Those parameters are given to the provider and
+        should be verified when the client comes back from the provider.
+        It is also used to store the client_redirect_url, which is used to
+        complete the SSO login flow.
+
+        Args:
+            state: The ``state`` parameter passed to the OIDC provider.
+            session_data: data to include in the session token.
+            duration_in_ms: An optional duration for the token in milliseconds.
+                Defaults to an hour.
+
+        Returns:
+            A signed macaroon token with the session information.
+        """
+        macaroon = pymacaroons.Macaroon(
+            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+        )
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = session")
+        macaroon.add_first_party_caveat("state = %s" % (state,))
+        macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
+        macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
+        macaroon.add_first_party_caveat(
+            "client_redirect_url = %s" % (session_data.client_redirect_url,)
+        )
+        if session_data.ui_auth_session_id:
+            macaroon.add_first_party_caveat(
+                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+            )
+        now = self._clock.time_msec()
+        expiry = now + duration_in_ms
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+        return macaroon.serialize()
+
+    def verify_oidc_session_token(
+        self, session: bytes, state: str
+    ) -> "OidcSessionData":
+        """Verifies and extract an OIDC session token.
+
+        This verifies that a given session token was issued by this homeserver
+        and extract the nonce and client_redirect_url caveats.
+
+        Args:
+            session: The session token to verify
+            state: The state the OIDC provider gave back
+
+        Returns:
+            The data extracted from the session cookie
+
+        Raises:
+            ValueError if an expected caveat is missing from the macaroon.
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(session)
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = session")
+        v.satisfy_exact("state = %s" % (state,))
+        v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("idp_id = "))
+        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
+        # to always satisfy this.
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+        v.satisfy_general(self._verify_expiry)
+
+        v.verify(macaroon, self._macaroon_secret_key)
+
+        # Extract the session data from the token.
+        nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
+        client_redirect_url = self._get_value_from_macaroon(
+            macaroon, "client_redirect_url"
+        )
+        try:
+            ui_auth_session_id = self._get_value_from_macaroon(
+                macaroon, "ui_auth_session_id"
+            )  # type: Optional[str]
+        except ValueError:
+            ui_auth_session_id = None
+
+        return OidcSessionData(
+            nonce=nonce,
+            idp_id=idp_id,
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=ui_auth_session_id,
+        )
+
+    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+        """Extracts a caveat value from a macaroon token.
+
+        Args:
+            macaroon: the token
+            key: the key of the caveat to extract
+
+        Returns:
+            The extracted value
+
+        Raises:
+            ValueError: if the caveat was not in the macaroon
+        """
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+    def _verify_expiry(self, caveat: str) -> bool:
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix) :])
+        now = self._clock.time_msec()
+        return now < expiry
+
+
+@attr.s(frozen=True, slots=True)
+class OidcSessionData:
+    """The attributes which are stored in a OIDC session cookie"""
+
+    # the Identity Provider being used
+    idp_id = attr.ib(type=str)
+
+    # The `nonce` parameter passed to the OIDC provider.
+    nonce = attr.ib(type=str)
+
+    # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
+    client_redirect_url = attr.ib(type=str)
+
+    # The session ID of the ongoing UI Auth (None if this is a login)
+    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+
+
 UserAttributeDict = TypedDict(
     "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
 )
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 36f9ee4b71..c02b951031 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
             )
 
+        avatar_url_to_set = new_avatar_url  # type: Optional[str]
+        if new_avatar_url == "":
+            avatar_url_to_set = None
+
         # Same like set_displayname
         if by_admin:
             requester = create_requester(
                 target_user, authenticated_entity=requester.authenticated_entity
             )
 
-        await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+        await self.store.set_profile_avatar_url(
+            target_user.localpart, avatar_url_to_set
+        )
 
         if self.hs.config.user_directory_search_all_users:
             profile = await self.store.get_profileinfo(target_user.localpart)
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index a7550806e6..6bb2fd936b 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
         super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
+        self.account_data_handler = hs.get_account_data_handler()
         self.read_marker_linearizer = Linearizer(name="read_marker")
-        self.notifier = hs.get_notifier()
 
     async def received_client_read_marker(
         self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
 
             if should_update:
                 content = {"event_id": event_id}
-                max_id = await self.store.add_account_data_to_room(
+                await self.account_data_handler.add_account_data_to_room(
                     user_id, room_id, "m.fully_read", content
                 )
-                self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a9abdf42e0..cc21fc2284 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.hs = hs
-        self.federation = hs.get_federation_sender()
-        hs.get_federation_registry().register_edu_handler(
-            "m.receipt", self._received_remote_receipt
-        )
+
+        # We only need to poke the federation sender explicitly if its on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the receipts stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the receipt EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.receipts:
+            hs.get_federation_registry().register_edu_handler(
+                "m.receipt", self._received_remote_receipt
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.receipt", hs.config.worker.writers.receipts,
+            )
+
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
 
@@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
         if not is_new:
             return
 
-        await self.federation.send_read_receipt(receipt)
+        if self.federation_sender:
+            await self.federation_sender.send_read_receipt(receipt)
 
 
 class ReceiptEventSource:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3bece6d668..ee27d99135 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -38,7 +38,6 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
-from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
@@ -55,6 +54,7 @@ from synapse.types import (
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_and_validate_server_name
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index cb5a29bc7e..e001e418f9 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.registration_handler = hs.get_registration_handler()
         self.profile_handler = hs.get_profile_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self.account_data_handler = hs.get_account_data_handler()
 
         self.member_linearizer = Linearizer(name="member")
 
@@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     direct_rooms[key].append(new_room_id)
 
                     # Save back to user's m.direct account data
-                    await self.store.add_account_data_for_user(
+                    await self.account_data_handler.add_account_data_for_user(
                         user_id, AccountDataTypes.DIRECT, direct_rooms
                     )
                     break
@@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Copy each room tag to the new room
         for tag, tag_content in room_tags.items():
-            await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+            await self.account_data_handler.add_tag_to_room(
+                user_id, new_room_id, tag, tag_content
+            )
 
     async def update_membership(
         self,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a8376543c9..38461cf79d 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -78,6 +78,10 @@ class SamlHandler(BaseHandler):
         # user-facing name of this auth provider
         self.idp_name = "SAML"
 
+        # we do not currently support icons for SAML auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
+
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 2da1ea2223..d493327a10 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -22,7 +22,10 @@ from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
 
+from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@@ -72,6 +75,11 @@ class SsoIdentityProvider(Protocol):
     def idp_name(self) -> str:
         """User-facing name for this provider"""
 
+    @property
+    def idp_icon(self) -> Optional[str]:
+        """Optional MXC URI for user-facing icon"""
+        return None
+
     @abc.abstractmethod
     async def handle_redirect_request(
         self,
@@ -145,8 +153,13 @@ class SsoHandler:
         self._store = hs.get_datastore()
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
-        self._error_template = hs.config.sso_error_template
         self._auth_handler = hs.get_auth_handler()
+        self._error_template = hs.config.sso_error_template
+        self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+        # The following template is shown after a successful user interactive
+        # authentication session. It tells the user they can close the window.
+        self._sso_auth_success_template = hs.config.sso_auth_success_template
 
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
@@ -166,6 +179,37 @@ class SsoHandler:
         """Get the configured identity providers"""
         return self._identity_providers
 
+    async def get_identity_providers_for_user(
+        self, user_id: str
+    ) -> Mapping[str, SsoIdentityProvider]:
+        """Get the SsoIdentityProviders which a user has used
+
+        Given a user id, get the identity providers that that user has used to log in
+        with in the past (and thus could use to re-identify themselves for UI Auth).
+
+        Args:
+            user_id: MXID of user to look up
+
+        Raises:
+            a map of idp_id to SsoIdentityProvider
+        """
+        external_ids = await self._store.get_external_ids_by_user(user_id)
+
+        valid_idps = {}
+        for idp_id, _ in external_ids:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                logger.warning(
+                    "User %r has an SSO mapping for IdP %r, but this is no longer "
+                    "configured.",
+                    user_id,
+                    idp_id,
+                )
+            else:
+                valid_idps[idp_id] = idp
+
+        return valid_idps
+
     def render_error(
         self,
         request: Request,
@@ -362,7 +406,7 @@ class SsoHandler:
                     attributes,
                     auth_provider_id,
                     remote_user_id,
-                    request.get_user_agent(""),
+                    get_request_user_agent(request),
                     request.getClientIP(),
                 )
 
@@ -545,19 +589,45 @@ class SsoHandler:
             auth_provider_id, remote_user_id,
         )
 
+        user_id_to_verify = await self._auth_handler.get_session_data(
+            ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
         if not user_id:
             logger.warning(
                 "Remote user %s/%s has not previously logged in here: UIA will fail",
                 auth_provider_id,
                 remote_user_id,
             )
-            # Let the UIA flow handle this the same as if they presented creds for a
-            # different user.
-            user_id = ""
+        elif user_id != user_id_to_verify:
+            logger.warning(
+                "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+                user_id,
+            )
+        else:
+            # success!
+            # Mark the stage of the authentication as successful.
+            await self._store.mark_ui_auth_stage_complete(
+                ui_auth_session_id, LoginType.SSO, user_id
+            )
+
+            # Render the HTML confirmation page and return.
+            html = self._sso_auth_success_template
+            respond_with_html(request, 200, html)
+            return
+
+        # the user_id didn't match: mark the stage of the authentication as unsuccessful
+        await self._store.mark_ui_auth_stage_complete(
+            ui_auth_session_id, LoginType.SSO, ""
+        )
 
-        await self._auth_handler.complete_sso_ui_auth(
-            user_id, ui_auth_session_id, request
+        # render an error page.
+        html = self._bad_user_template.render(
+            server_name=self._server_name, user_id_to_verify=user_id_to_verify,
         )
+        respond_with_html(request, 200, html)
 
     async def check_username_availability(
         self, localpart: str, session_id: str,
@@ -628,7 +698,7 @@ class SsoHandler:
             attributes,
             session.auth_provider_id,
             session.remote_user_id,
-            request.get_user_agent(""),
+            get_request_user_agent(request),
             request.getClientIP(),
         )
 
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 824f37f8f8..a68d5e790e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
 """
 
 from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS  # noqa: F401
+
+
+class UIAuthSessionDataConstants:
+    """Constants for use with AuthHandler.set_session_data"""
+
+    # used during registration and password reset to store a hashed copy of the
+    # password, so that the client does not need to submit it each time.
+    PASSWORD_HASH = "password_hash"
+
+    # used during registration to store the mxid of the registered user
+    REGISTERED_USER_ID = "registered_user_id"
+
+    # used by validate_user_via_ui_auth to store the mxid of the user we are validating
+    # for.
+    REQUEST_USER_ID = "request_user_id"
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 59b01b812c..4bc3cb53f0 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -17,6 +17,7 @@ import re
 
 from twisted.internet import task
 from twisted.web.client import FileBodyProducer
+from twisted.web.iweb import IRequest
 
 from synapse.api.errors import SynapseError
 
@@ -50,3 +51,17 @@ class QuieterFileBodyProducer(FileBodyProducer):
             FileBodyProducer.stopProducing(self)
         except task.TaskStopped:
             pass
+
+
+def get_request_user_agent(request: IRequest, default: str = "") -> str:
+    """Return the last User-Agent header, or the given default.
+    """
+    # There could be raw utf-8 bytes in the User-Agent header.
+
+    # N.B. if you don't do this, the logger explodes cryptically
+    # with maximum recursion trying to log errors about
+    # the charset problem.
+    # c.f. https://github.com/matrix-org/synapse/issues/3471
+
+    h = request.getHeader(b"User-Agent")
+    return h.decode("ascii", "replace") if h else default
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 29f40ddf5f..37ccf5ab98 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -32,7 +32,7 @@ from typing import (
 
 import treq
 from canonicaljson import encode_canonical_json
-from netaddr import IPAddress, IPSet
+from netaddr import AddrFormatError, IPAddress, IPSet
 from prometheus_client import Counter
 from zope.interface import implementer, provider
 
@@ -261,16 +261,16 @@ class BlacklistingAgentWrapper(Agent):
 
         try:
             ip_address = IPAddress(h.hostname)
-
+        except AddrFormatError:
+            # Not an IP
+            pass
+        else:
             if check_against_blacklist(
                 ip_address, self._ip_whitelist, self._ip_blacklist
             ):
                 logger.info("Blocking access to %s due to blacklist" % (ip_address,))
                 e = SynapseError(403, "IP address blocked by IP blacklist entry")
                 return defer.fail(Failure(e))
-        except Exception:
-            # Not an IP
-            pass
 
         return self._agent.request(
             method, uri, headers=headers, bodyProducer=bodyProducer
@@ -341,6 +341,7 @@ class SimpleHttpClient:
 
         self.agent = ProxyAgent(
             self.reactor,
+            hs.get_reactor(),
             connectTimeout=15,
             contextFactory=self.hs.get_http_client_context_factory(),
             pool=pool,
@@ -723,7 +724,7 @@ class SimpleHttpClient:
                 read_body_with_max_size(response, output_stream, max_size)
             )
         except BodyExceededMaxSize:
-            SynapseError(
+            raise SynapseError(
                 502,
                 "Requested file is too large > %r bytes" % (max_size,),
                 Codes.TOO_LARGE,
@@ -765,14 +766,24 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         self.max_size = max_size
 
     def dataReceived(self, data: bytes) -> None:
+        # If the deferred was called, bail early.
+        if self.deferred.called:
+            return
+
         self.stream.write(data)
         self.length += len(data)
+        # The first time the maximum size is exceeded, error and cancel the
+        # connection. dataReceived might be called again if data was received
+        # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
             self.deferred.errback(BodyExceededMaxSize())
-            self.deferred = defer.Deferred()
             self.transport.loseConnection()
 
     def connectionLost(self, reason: Failure) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
         if reason.check(ResponseDone):
             self.deferred.callback(self.length)
         elif reason.check(PotentialDataLoss):
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
deleted file mode 100644
index 92a5b606c8..0000000000
--- a/synapse/http/endpoint.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import re
-
-logger = logging.getLogger(__name__)
-
-
-def parse_server_name(server_name):
-    """Split a server name into host/port parts.
-
-    Args:
-        server_name (str): server name to parse
-
-    Returns:
-        Tuple[str, int|None]: host/port parts.
-
-    Raises:
-        ValueError if the server name could not be parsed.
-    """
-    try:
-        if server_name[-1] == "]":
-            # ipv6 literal, hopefully
-            return server_name, None
-
-        domain_port = server_name.rsplit(":", 1)
-        domain = domain_port[0]
-        port = int(domain_port[1]) if domain_port[1:] else None
-        return domain, port
-    except Exception:
-        raise ValueError("Invalid server name '%s'" % server_name)
-
-
-VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
-
-
-def parse_and_validate_server_name(server_name):
-    """Split a server name into host/port parts and do some basic validation.
-
-    Args:
-        server_name (str): server name to parse
-
-    Returns:
-        Tuple[str, int|None]: host/port parts.
-
-    Raises:
-        ValueError if the server name could not be parsed.
-    """
-    host, port = parse_server_name(server_name)
-
-    # these tests don't need to be bulletproof as we'll find out soon enough
-    # if somebody is giving us invalid data. What we *do* need is to be sure
-    # that nobody is sneaking IP literals in that look like hostnames, etc.
-
-    # look for ipv6 literals
-    if host[0] == "[":
-        if host[-1] != "]":
-            raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
-        return host, port
-
-    # otherwise it should only be alphanumerics.
-    if not VALID_HOST_REGEX.match(host):
-        raise ValueError(
-            "Server name '%s' contains invalid characters" % (server_name,)
-        )
-
-    return host, port
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 3b756a7dc2..4c06a117d3 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -102,7 +102,6 @@ class MatrixFederationAgent:
                         pool=self._pool,
                         contextFactory=tls_client_options_factory,
                     ),
-                    self._reactor,
                     ip_blacklist=ip_blacklist,
                 ),
                 user_agent=self.user_agent,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b261e078c4..19293bf673 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -174,6 +174,16 @@ async def _handle_json_response(
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
         body = await make_deferred_yieldable(d)
+    except ValueError as e:
+        # The JSON content was invalid.
+        logger.warning(
+            "{%s} [%s] Failed to parse JSON response - %s %s",
+            request.txn_id,
+            request.destination,
+            request.method,
+            request.uri.decode("ascii"),
+        )
+        raise RequestSendFailed(e, can_retry=False) from e
     except defer.TimeoutError as e:
         logger.warning(
             "{%s} [%s] Timed out reading response - %s %s",
@@ -986,7 +996,7 @@ class MatrixFederationHttpClient:
             logger.warning(
                 "{%s} [%s] %s", request.txn_id, request.destination, msg,
             )
-            SynapseError(502, msg, Codes.TOO_LARGE)
+            raise SynapseError(502, msg, Codes.TOO_LARGE)
         except Exception as e:
             logger.warning(
                 "{%s} [%s] Error reading response: %s",
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index e32d3f43e0..b730d2c634 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -39,6 +39,10 @@ class ProxyAgent(_AgentBase):
         reactor: twisted reactor to place outgoing
             connections.
 
+        proxy_reactor: twisted reactor to use for connections to the proxy server
+                       reactor might have some blacklisting applied (i.e. for DNS queries),
+                       but we need unblocked access to the proxy.
+
         contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
             verification parameters of OpenSSL.  The default is to use a
             `BrowserLikePolicyForHTTPS`, so unless you have special
@@ -59,6 +63,7 @@ class ProxyAgent(_AgentBase):
     def __init__(
         self,
         reactor,
+        proxy_reactor=None,
         contextFactory=BrowserLikePolicyForHTTPS(),
         connectTimeout=None,
         bindAddress=None,
@@ -68,6 +73,11 @@ class ProxyAgent(_AgentBase):
     ):
         _AgentBase.__init__(self, reactor, pool)
 
+        if proxy_reactor is None:
+            self.proxy_reactor = reactor
+        else:
+            self.proxy_reactor = proxy_reactor
+
         self._endpoint_kwargs = {}
         if connectTimeout is not None:
             self._endpoint_kwargs["timeout"] = connectTimeout
@@ -75,11 +85,11 @@ class ProxyAgent(_AgentBase):
             self._endpoint_kwargs["bindAddress"] = bindAddress
 
         self.http_proxy_endpoint = _http_proxy_endpoint(
-            http_proxy, reactor, **self._endpoint_kwargs
+            http_proxy, self.proxy_reactor, **self._endpoint_kwargs
         )
 
         self.https_proxy_endpoint = _http_proxy_endpoint(
-            https_proxy, reactor, **self._endpoint_kwargs
+            https_proxy, self.proxy_reactor, **self._endpoint_kwargs
         )
 
         self._policy_for_https = contextFactory
@@ -137,7 +147,7 @@ class ProxyAgent(_AgentBase):
             request_path = uri
         elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
             endpoint = HTTPConnectProxyEndpoint(
-                self._reactor,
+                self.proxy_reactor,
                 self.https_proxy_endpoint,
                 parsed_uri.host,
                 parsed_uri.port,
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5a5790831b..12ec3f851f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -20,7 +20,7 @@ from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
 
 from synapse.config.server import ListenerConfig
-from synapse.http import redact_uri
+from synapse.http import get_request_user_agent, redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.types import Requester
@@ -113,15 +113,6 @@ class SynapseRequest(Request):
             method = self.method.decode("ascii")
         return method
 
-    def get_user_agent(self, default: str) -> str:
-        """Return the last User-Agent header, or the given default.
-        """
-        user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
-        if user_agent is None:
-            return default
-
-        return user_agent.decode("ascii", "replace")
-
     def render(self, resrc):
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.
@@ -292,12 +283,7 @@ class SynapseRequest(Request):
             # and can see that we're doing something wrong.
             authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
 
-        # ...or could be raw utf-8 bytes in the User-Agent header.
-        # N.B. if you don't do this, the logger explodes cryptically
-        # with maximum recursion trying to log errors about
-        # the charset problem.
-        # c.f. https://github.com/matrix-org/synapse/issues/3471
-        user_agent = self.get_user_agent("-")
+        user_agent = get_request_user_agent(self, "-")
 
         code = str(self.code)
         if not self.finished:
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index a84a064c8d..dd527e807f 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -15,6 +15,7 @@
 
 from synapse.http.server import JsonResource
 from synapse.replication.http import (
+    account_data,
     devices,
     federation,
     login,
@@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
         presence.register_servlets(hs, self)
         membership.register_servlets(hs, self)
         streams.register_servlets(hs, self)
+        account_data.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 1492ac922c..288727a566 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -177,7 +177,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         @trace(opname="outgoing_replication_request")
         @outgoing_gauge.track_inprogress()
-        async def send_request(instance_name="master", **kwargs):
+        async def send_request(*, instance_name="master", **kwargs):
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
new file mode 100644
index 0000000000..52d32528ee
--- /dev/null
+++ b/synapse/replication/http/account_data.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+    """Add user account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_user_account_data"
+    PATH_ARGS = ("user_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_for_user(
+            user_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+    """Add room account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_room_account_data"
+    PATH_ARGS = ("user_id", "room_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_to_room(
+            user_id, room_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+    """Add tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_tag"
+    PATH_ARGS = ("user_id", "room_id", "tag")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_tag_to_room(
+            user_id, room_id, tag, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+    """Remove tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+        {}
+
+    """
+
+    NAME = "remove_tag"
+    PATH_ARGS = (
+        "user_id",
+        "room_id",
+        "tag",
+    )
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag):
+
+        return {}
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+    ReplicationUserAccountDataRestServlet(hs).register(http_server)
+    ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+    ReplicationAddTagRestServlet(hs).register(http_server)
+    ReplicationRemoveTagRestServlet(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d0089fe06c..693c9ab901 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 4268565fc8..21afe5f155 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,47 +15,9 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = SlavedIdTracker(
-            db_conn,
-            "account_data",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self):
-        return self._account_data_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                if not row.room_id:
-                    self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id)
-                    )
-                self.get_account_data_for_user.invalidate((row.user_id,))
-                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
-                self.get_account_data_for_room_and_type.invalidate(
-                    (row.user_id, row.room_id, row.data_type)
-                )
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6195917376..3dfdd9961d 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,43 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = SlavedIdTracker(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
-        )
-        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
-        self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ReceiptsStream.NAME:
-            self._receipts_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.invalidate_caches_for_receipt(
-                    row.room_id, row.receipt_type, row.user_id
-                )
-                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1f89249475..317796d5e0 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
+    AccountDataStream,
     BackfillStream,
     CachesStream,
     EventsStream,
     FederationStream,
+    ReceiptsStream,
     Stream,
+    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
 )
@@ -132,6 +135,22 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+                # Only add AccountDataStream and TagAccountDataStream as a source on the
+                # instance in charge of account_data persistence.
+                if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
+            if isinstance(stream, ReceiptsStream):
+                # Only add ReceiptsStream as a source on the instance in charge of
+                # receipts.
+                if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
new file mode 100644
index 0000000000..3611191bf9
--- /dev/null
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -0,0 +1,18 @@
+<html>
+<head>
+    <title>Authentication Failed</title>
+</head>
+    <body>
+        <div>
+            <p>
+                We were unable to validate your <tt>{{server_name | e}}</tt> account via
+                single-sign-on (SSO), because the SSO Identity Provider returned
+                different details than when you logged in.
+            </p>
+            <p>
+                Try the operation again, and ensure that you use the same details on
+                the Identity Provider as when you log into your account.
+            </p>
+        </div>
+    </body>
+</html>
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index f53c9cd679..5b38481012 100644
--- a/synapse/res/templates/sso_login_idp_picker.html
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -17,6 +17,9 @@
                     <li>
                         <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
                         <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
+{% if p.idp_icon %}
+                        <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/>
+{% endif %}
                     </li>
 {% endfor %}
                 </ul>
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..8720b1401f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
     assert_requester_is_admin,
     assert_user_is_admin,
 )
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
         admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, room_id: str):
+    async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
 
     PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id: str):
+    async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
         "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, server_name: str, media_id: str):
+    async def on_POST(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
         return 200, {}
 
 
+class ProtectMediaByID(RestServlet):
+    """Protect local media from being quarantined.
+    """
+
+    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Protecting local media by ID: %s", media_id)
+
+        # Quarantine this media id
+        await self.store.mark_local_media_as_safe(media_id)
+
+        return 200, {}
+
+
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room.
     """
 
     PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
 class PurgeMediaCacheRestServlet(RestServlet):
     PATTERNS = admin_patterns("/purge_media_cache")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.media_repository = hs.get_media_repository()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_DELETE(self, request, server_name: str, media_id: str):
+    async def on_DELETE(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_POST(self, request, server_name: str):
+    async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
         return 200, {"deleted_media": deleted_media, "total": total}
 
 
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
     """
     Media repo specific APIs.
     """
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
     QuarantineMediaInRoom(hs).register(http_server)
     QuarantineMediaByID(hs).register(http_server)
     QuarantineMediaByUser(hs).register(http_server)
+    ProtectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
     DeleteMediaByID(hs).register(http_server)
     DeleteMediaByDateSize(hs).register(http_server)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 6658c2da56..86198bab30 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
     The parameter `deactivated` can be used to include deactivated users.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
         user_id = parse_string(request, "user_id", default=None)
         name = parse_string(request, "name", default=None)
         guests = parse_boolean(request, "guests", default=True)
@@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
             start, limit, user_id, name, guests, deactivated
         )
         ret = {"users": users, "total": total}
-        if len(users) >= limit:
+        if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
 
         return 200, ret
@@ -244,7 +259,7 @@ class UserRestServletV2(RestServlet):
 
                 if deactivate and not user["deactivated"]:
                     await self.deactivate_account_handler.deactivate_account(
-                        target_user.to_string(), False
+                        target_user.to_string(), False, requester, by_admin=True
                     )
                 elif not deactivate and user["deactivated"]:
                     if "password" not in body:
@@ -486,12 +501,22 @@ class WhoisRestServlet(RestServlet):
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
+        self.store = hs.get_datastore()
+
+    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        if not self.is_mine(UserID.from_string(target_user_id)):
+            raise SynapseError(400, "Can only deactivate local users")
+
+        if not await self.store.get_user_by_id(target_user_id):
+            raise NotFoundError("User not found")
 
-    async def on_POST(self, request, target_user_id):
-        await assert_requester_is_admin(self.auth, request)
         body = parse_json_object_from_request(request, allow_empty_body=True)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -501,10 +526,8 @@ class DeactivateAccountRestServlet(RestServlet):
                 Codes.BAD_JSON,
             )
 
-        UserID.from_string(target_user_id)
-
         result = await self._deactivate_account_handler.deactivate_account(
-            target_user_id, erase
+            target_user_id, erase, requester, by_admin=True
         )
         if result:
             id_server_unbind_result = "success"
@@ -714,13 +737,6 @@ class UserMembershipRestServlet(RestServlet):
     async def on_GET(self, request, user_id):
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only lookup local users")
-
-        user = await self.store.get_user_by_id(user_id)
-        if user is None:
-            raise NotFoundError("Unknown user")
-
         room_ids = await self.store.get_rooms_for_user(user_id)
         ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
         return 200, ret
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 5647e8c577..f95627ee61 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.util import json_decoder
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
 if TYPE_CHECKING:
     import synapse.server
@@ -347,8 +347,6 @@ class PublicRoomListRestServlet(TransactionRestServlet):
             # provided.
             if server:
                 raise e
-            else:
-                pass
 
         limit = parse_integer(request, "limit", 0)
         since_token = parse_string(request, "since", None)
@@ -359,6 +357,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server, limit=limit, since_token=since_token
@@ -402,6 +408,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server,
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index d837bde1d6..65e68d641b 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,9 +20,6 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING
 from urllib.parse import urlparse
 
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -31,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
 )
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
 
 from ._base import client_patterns, interactive_auth_handler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -189,11 +191,7 @@ class PasswordRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             try:
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
-                    requester,
-                    request,
-                    body,
-                    self.hs.get_ip_from_request(request),
-                    "modify your account password",
+                    requester, request, body, "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
                 # The user needs to provide more steps to complete auth, but
@@ -204,7 +202,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
             user_id = requester.user.to_string()
@@ -215,7 +215,6 @@ class PasswordRestServlet(RestServlet):
                     [[LoginType.EMAIL_IDENTITY]],
                     request,
                     body,
-                    self.hs.get_ip_from_request(request),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -227,7 +226,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
 
@@ -260,7 +261,7 @@ class PasswordRestServlet(RestServlet):
             password_hash = await self.auth_handler.hash(new_password)
         elif session_id is not None:
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
         else:
             # UI validation was skipped, but the request did not include a new
@@ -304,19 +305,18 @@ class DeactivateAccountRestServlet(RestServlet):
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase
+                requester.user.to_string(), erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "deactivate your account",
+            requester, request, body, "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(), erase, id_server=body.get("id_server")
+            requester.user.to_string(),
+            erase,
+            requester,
+            id_server=body.get("id_server"),
         )
         if result:
             id_server_unbind_result = "success"
@@ -695,11 +695,7 @@ class ThreepidAddRestServlet(RestServlet):
         assert_valid_client_secret(client_secret)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a third-party identifier to your account",
+            requester, request, body, "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..3f28c0bc3e 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_account_data_for_user(
-            user_id, account_data_type, body
-        )
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_account_data_for_user(user_id, account_data_type, body)
 
         return 200, {}
 
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
                 " Use /rooms/!roomId:server.name/read_markers",
             )
 
-        max_id = await self.store.add_account_data_to_room(
+        await self.handler.add_account_data_to_room(
             user_id, room_id, account_data_type, body
         )
 
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
         return 200, {}
 
     async def on_GET(self, request, user_id, room_id, account_data_type):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 9b9514632f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -19,7 +19,6 @@ from typing import TYPE_CHECKING
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
 from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.handlers.sso import SsoIdentityProvider
 from synapse.http.server import respond_with_html
 from synapse.http.servlet import RestServlet, parse_string
 
@@ -46,22 +45,6 @@ class AuthRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
-
-        # SSO configuration.
-        self._cas_enabled = hs.config.cas_enabled
-        if self._cas_enabled:
-            self._cas_handler = hs.get_cas_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
-        self._oidc_enabled = hs.config.oidc_enabled
-        if self._oidc_enabled:
-            self._oidc_handler = hs.get_oidc_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
         self.success_template = hs.config.fallback_success_template
@@ -90,21 +73,7 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-
-            if self._cas_enabled:
-                sso_auth_provider = self._cas_handler  # type: SsoIdentityProvider
-            elif self._saml_enabled:
-                sso_auth_provider = self._saml_handler
-            elif self._oidc_enabled:
-                sso_auth_provider = self._oidc_handler
-            else:
-                raise SynapseError(400, "Homeserver not configured for SSO.")
-
-            sso_redirect_url = await sso_auth_provider.handle_redirect_request(
-                request, None, session
-            )
-
-            html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+            html = await self.auth_handler.start_sso_ui_auth(request, session)
 
         else:
             raise SynapseError(404, "Unknown auth stage type")
@@ -128,7 +97,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"response": response, "session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+                LoginType.RECAPTCHA, authdict, request.getClientIP()
             )
 
             if success:
@@ -144,7 +113,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+                LoginType.TERMS, authdict, request.getClientIP()
             )
 
             if success:
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index af117cb27c..314e01dfe4 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
         assert_params_in_dict(body, ["devices"])
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove device(s) from your account",
+            requester, request, body, "remove device(s) from your account",
         )
 
         await self.device_handler.delete_devices(
@@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
                 raise
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove a device from your account",
+            requester, request, body, "remove a device from your account",
         )
 
         await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index b91996c738..a6134ead8a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
         body = parse_json_object_from_request(request)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a device signing key to your account",
+            requester, request, body, "add a device signing key to your account",
         )
 
         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 6b5a1b7109..b093183e79 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.config.registration import RegistrationConfig
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -353,7 +354,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
                 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
             )
 
-        ip = self.hs.get_ip_from_request(request)
+        ip = request.getClientIP()
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             await wait_deferred
 
@@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet):
             # user here. We carry on and go through the auth checks though,
             # for paranoia.
             registered_user_id = await self.auth_handler.get_session_data(
-                session_id, "registered_user_id", None
+                session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
             )
             # Extract the previously-hashed password from the session.
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
 
         # Ensure that the username is valid.
@@ -513,11 +514,7 @@ class RegisterRestServlet(RestServlet):
         # not this will raise a user-interactive auth error.
         try:
             auth_result, params, session_id = await self.auth_handler.check_ui_auth(
-                self._registration_flows,
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "register a new account",
+                self._registration_flows, request, body, "register a new account",
             )
         except InteractiveAuthIncompleteError as e:
             # The user needs to provide more steps to complete auth.
@@ -532,7 +529,9 @@ class RegisterRestServlet(RestServlet):
             if not password_hash and password:
                 password_hash = await self.auth_handler.hash(password)
                 await self.auth_handler.set_session_data(
-                    e.session_id, "password_hash", password_hash
+                    e.session_id,
+                    UIAuthSessionDataConstants.PASSWORD_HASH,
+                    password_hash,
                 )
             raise
 
@@ -633,7 +632,9 @@ class RegisterRestServlet(RestServlet):
             # Remember that the user account has been registered (and the user
             # ID it was registered with, since it might not have been specified).
             await self.auth_handler.set_session_data(
-                session_id, "registered_user_id", registered_user_id
+                session_id,
+                UIAuthSessionDataConstants.REGISTERED_USER_ID,
+                registered_user_id,
             )
 
             registered = True
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
     def __init__(self, hs):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, tag):
         requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_tag_to_room(user_id, room_id, tag, body)
 
         return 200, {}
 
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
 
-        max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.remove_tag_from_room(user_id, room_id, tag)
 
         return 200, {}
 
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 47c2b44bff..f71a03a12d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
 import logging
 import os
 import urllib
-from typing import Awaitable
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple
 
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
@@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
 ]
 
 
-def parse_media_id(request):
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
     try:
         # This allows users to append e.g. /test.png to the URL. Useful for
         # clients that parse the URL to see content type.
@@ -69,7 +70,7 @@ def parse_media_id(request):
         )
 
 
-def respond_404(request):
+def respond_404(request: Request) -> None:
     respond_with_json(
         request,
         404,
@@ -79,8 +80,12 @@ def respond_404(request):
 
 
 async def respond_with_file(
-    request, media_type, file_path, file_size=None, upload_name=None
-):
+    request: Request,
+    media_type: str,
+    file_path: str,
+    file_size: Optional[int] = None,
+    upload_name: Optional[str] = None,
+) -> None:
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
@@ -98,15 +103,20 @@ async def respond_with_file(
         respond_404(request)
 
 
-def add_file_headers(request, media_type, file_size, upload_name):
+def add_file_headers(
+    request: Request,
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str],
+) -> None:
     """Adds the correct response headers in preparation for responding with the
     media.
 
     Args:
-        request (twisted.web.http.Request)
-        media_type (str): The media/content type.
-        file_size (int): Size in bytes of the media, if known.
-        upload_name (str): The name of the requested file, if any.
+        request
+        media_type: The media/content type.
+        file_size: Size in bytes of the media, if known.
+        upload_name: The name of the requested file, if any.
     """
 
     def _quote(x):
@@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name):
     # select private. don't bother setting Expires as all our
     # clients are smart enough to be happy with Cache-Control
     request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
-    request.setHeader(b"Content-Length", b"%d" % (file_size,))
+    if file_size is not None:
+        request.setHeader(b"Content-Length", b"%d" % (file_size,))
 
     # Tell web crawlers to not index, archive, or follow links in media. This
     # should help to prevent things in the media repo from showing up in web
@@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
 }
 
 
-def _can_encode_filename_as_token(x):
+def _can_encode_filename_as_token(x: str) -> bool:
     for c in x:
         # from RFC2616:
         #
@@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):
 
 
 async def respond_with_responder(
-    request, responder, media_type, file_size, upload_name=None
-):
+    request: Request,
+    responder: "Optional[Responder]",
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str] = None,
+) -> None:
     """Responds to the request with given responder. If responder is None then
     returns 404.
 
     Args:
-        request (twisted.web.http.Request)
-        responder (Responder|None)
-        media_type (str): The media/content type.
-        file_size (int|None): Size in bytes of the media. If not known it should be None
-        upload_name (str|None): The name of the requested file, if any.
+        request
+        responder
+        media_type: The media/content type.
+        file_size: Size in bytes of the media. If not known it should be None
+        upload_name: The name of the requested file, if any.
     """
     if request._disconnected:
         logger.warning(
@@ -285,6 +300,7 @@ class FileInfo:
         thumbnail_height (int)
         thumbnail_method (str)
         thumbnail_type (str): Content type of thumbnail, e.g. image/png
+        thumbnail_length (int): The size of the media file, in bytes.
     """
 
     def __init__(
@@ -297,6 +313,7 @@ class FileInfo:
         thumbnail_height=None,
         thumbnail_method=None,
         thumbnail_type=None,
+        thumbnail_length=None,
     ):
         self.server_name = server_name
         self.file_id = file_id
@@ -306,24 +323,25 @@ class FileInfo:
         self.thumbnail_height = thumbnail_height
         self.thumbnail_method = thumbnail_method
         self.thumbnail_type = thumbnail_type
+        self.thumbnail_length = thumbnail_length
 
 
-def get_filename_from_headers(headers):
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
     """
     Get the filename of the downloaded file by inspecting the
     Content-Disposition HTTP header.
 
     Args:
-        headers (dict[bytes, list[bytes]]): The HTTP request headers.
+        headers: The HTTP request headers.
 
     Returns:
-        A Unicode string of the filename, or None.
+        The filename, or None.
     """
     content_disposition = headers.get(b"Content-Disposition", [b""])
 
     # No header, bail out.
     if not content_disposition[0]:
-        return
+        return None
 
     _, params = _parse_header(content_disposition[0])
 
@@ -356,17 +374,16 @@ def get_filename_from_headers(headers):
     return upload_name
 
 
-def _parse_header(line):
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
     """Parse a Content-type like header.
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        line (bytes): header to be parsed
+        line: header to be parsed
 
     Returns:
-        Tuple[bytes, dict[bytes, bytes]]:
-            the main content-type, followed by the parameter dictionary
+        The main content-type, followed by the parameter dictionary
     """
     parts = _parseparam(b";" + line)
     key = next(parts)
@@ -386,16 +403,16 @@ def _parse_header(line):
     return key, pdict
 
 
-def _parseparam(s):
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
     """Generator which splits the input on ;, respecting double-quoted sequences
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        s (bytes): header to be parsed
+        s: header to be parsed
 
     Returns:
-        Iterable[bytes]: the split input
+        The split input
     """
     while s[:1] == b";":
         s = s[1:]
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 68dd2a1c8a..4e4c6971f7 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 Will Hunt <will@half-shot.uk>
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,22 +15,29 @@
 # limitations under the License.
 #
 
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class MediaConfigResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         config = hs.get_config()
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index d3d8457303..3ed219ae43 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
-import synapse.http.servlet
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.servlet import parse_boolean
 
 from ._base import parse_media_id, respond_404
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class DownloadResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
         self.media_repo = media_repo
         self.server_name = hs.hostname
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         request.setHeader(
             b"Content-Security-Policy",
@@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
         if server_name == self.server_name:
             await self.media_repo.get_local_media(request, media_id, name)
         else:
-            allow_remote = synapse.http.servlet.parse_boolean(
-                request, "allow_remote", default=True
-            )
+            allow_remote = parse_boolean(request, "allow_remote", default=True)
             if not allow_remote:
                 logger.info(
                     "Rejecting request for remote media %s/%s due to allow_remote",
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 9e079f672f..7792f26e78 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
 import functools
 import os
 import re
+from typing import Callable, List
 
 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
 
 
-def _wrap_in_base_path(func):
+def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
     """Takes a function that returns a relative path and turns it into an
     absolute path based on the location of the primary media store
     """
@@ -41,12 +43,18 @@ class MediaFilePaths:
     to write to the backup media store (when one is configured)
     """
 
-    def __init__(self, primary_base_path):
+    def __init__(self, primary_base_path: str):
         self.base_path = primary_base_path
 
     def default_thumbnail_rel(
-        self, default_top_level, default_sub_type, width, height, content_type, method
-    ):
+        self,
+        default_top_level: str,
+        default_sub_type: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:
 
     default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
 
-    def local_media_filepath_rel(self, media_id):
+    def local_media_filepath_rel(self, media_id: str) -> str:
         return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
 
     local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
 
-    def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def local_media_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
             media_id[4:],
         )
 
-    def remote_media_filepath_rel(self, server_name, file_id):
+    def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
         )
@@ -94,8 +104,14 @@ class MediaFilePaths:
     remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
 
     def remote_media_thumbnail_rel(
-        self, server_name, file_id, width, height, content_type, method
-    ):
+        self,
+        server_name: str,
+        file_id: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
     # Should be removed after some time, when most of the thumbnails are stored
     # using the new path.
     def remote_media_thumbnail_rel_legacy(
-        self, server_name, file_id, width, height, content_type
+        self, server_name: str, file_id: str, width: int, height: int, content_type: str
     ):
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@@ -126,7 +142,7 @@ class MediaFilePaths:
             file_name,
         )
 
-    def remote_media_thumbnail_dir(self, server_name, file_id):
+    def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             self.base_path,
             "remote_thumbnail",
@@ -136,7 +152,7 @@ class MediaFilePaths:
             file_id[4:],
         )
 
-    def url_cache_filepath_rel(self, media_id):
+    def url_cache_filepath_rel(self, media_id: str) -> str:
         if NEW_FORMAT_ID_RE.match(media_id):
             # Media id is of the form <DATE><RANDOM_STRING>
             # E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -146,7 +162,7 @@ class MediaFilePaths:
 
     url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
 
-    def url_cache_filepath_dirs_to_delete(self, media_id):
+    def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id file"
         if NEW_FORMAT_ID_RE.match(media_id):
             return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@@ -156,7 +172,9 @@ class MediaFilePaths:
                 os.path.join(self.base_path, "url_cache", media_id[0:2]),
             ]
 
-    def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def url_cache_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
 
@@ -178,7 +196,7 @@ class MediaFilePaths:
 
     url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
 
-    def url_cache_thumbnail_directory(self, media_id):
+    def url_cache_thumbnail_directory(self, media_id: str) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
 
@@ -195,7 +213,7 @@ class MediaFilePaths:
                 media_id[4:],
             )
 
-    def url_cache_thumbnail_dirs_to_delete(self, media_id):
+    def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id thumbnails"
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 83beb02b05..4c9946a616 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import errno
 import logging
 import os
 import shutil
-from typing import IO, Dict, List, Optional, Tuple
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 import twisted.internet.error
 import twisted.web.http
@@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
 from .thumbnailer import Thumbnailer, ThumbnailError
 from .upload_resource import UploadResource
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 
 
 class MediaRepository:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.client = hs.get_federation_http_client()
@@ -73,16 +76,16 @@ class MediaRepository:
         self.max_upload_size = hs.config.max_upload_size
         self.max_image_pixels = hs.config.max_image_pixels
 
-        self.primary_base_path = hs.config.media_store_path
-        self.filepaths = MediaFilePaths(self.primary_base_path)
+        self.primary_base_path = hs.config.media_store_path  # type: str
+        self.filepaths = MediaFilePaths(self.primary_base_path)  # type: MediaFilePaths
 
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements
 
         self.remote_media_linearizer = Linearizer(name="media_remote")
 
-        self.recently_accessed_remotes = set()
-        self.recently_accessed_locals = set()
+        self.recently_accessed_remotes = set()  # type: Set[Tuple[str, str]]
+        self.recently_accessed_locals = set()  # type: Set[str]
 
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
@@ -113,7 +116,7 @@ class MediaRepository:
             "update_recently_accessed_media", self._update_recently_accessed
         )
 
-    async def _update_recently_accessed(self):
+    async def _update_recently_accessed(self) -> None:
         remote_media = self.recently_accessed_remotes
         self.recently_accessed_remotes = set()
 
@@ -124,12 +127,12 @@ class MediaRepository:
             local_media, remote_media, self.clock.time_msec()
         )
 
-    def mark_recently_accessed(self, server_name, media_id):
+    def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
         """Mark the given media as recently accessed.
 
         Args:
-            server_name (str|None): Origin server of media, or None if local
-            media_id (str): The media ID of the content
+            server_name: Origin server of media, or None if local
+            media_id: The media ID of the content
         """
         if server_name:
             self.recently_accessed_remotes.add((server_name, media_id))
@@ -459,7 +462,14 @@ class MediaRepository:
     def _get_thumbnail_requirements(self, media_type):
         return self.thumbnail_requirements.get(media_type, ())
 
-    def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
+    def _generate_thumbnail(
+        self,
+        thumbnailer: Thumbnailer,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[BytesIO]:
         m_width = thumbnailer.width
         m_height = thumbnailer.height
 
@@ -470,22 +480,20 @@ class MediaRepository:
                 m_height,
                 self.max_image_pixels,
             )
-            return
+            return None
 
         if thumbnailer.transpose_method is not None:
             m_width, m_height = thumbnailer.transpose()
 
         if t_method == "crop":
-            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
+            return thumbnailer.crop(t_width, t_height, t_type)
         elif t_method == "scale":
             t_width, t_height = thumbnailer.aspect(t_width, t_height)
             t_width = min(m_width, t_width)
             t_height = min(m_height, t_height)
-            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
-        else:
-            t_byte_source = None
+            return thumbnailer.scale(t_width, t_height, t_type)
 
-        return t_byte_source
+        return None
 
     async def generate_local_exact_thumbnail(
         self,
@@ -776,7 +784,7 @@ class MediaRepository:
 
         return {"width": m_width, "height": m_height}
 
-    async def delete_old_remote_media(self, before_ts):
+    async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
         old_media = await self.store.get_remote_media_before(before_ts)
 
         deleted = 0
@@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
     within a given rectangle.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         # If we're not configured to use it, raise if we somehow got here.
         if not hs.config.can_load_media_repo:
             raise ConfigError("Synapse is not configured to use a media repo.")
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 268e0c8f50..89cdd605aa 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import os
 import shutil
 from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
 
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@@ -270,7 +272,7 @@ class MediaStorage:
         return self.filepaths.local_media_filepath_rel(file_info.file_id)
 
 
-def _write_file_synchronously(source, dest):
+def _write_file_synchronously(source: IO, dest: IO) -> None:
     """Write `source` to the file like `dest` synchronously. Should be called
     from a thread.
 
@@ -286,14 +288,14 @@ class FileResponder(Responder):
     """Wraps an open file that can be sent to a request.
 
     Args:
-        open_file (file): A file like object to be streamed ot the client,
+        open_file: A file like object to be streamed ot the client,
             is closed when finished streaming.
     """
 
-    def __init__(self, open_file):
+    def __init__(self, open_file: IO):
         self.open_file = open_file
 
-    def write_to_consumer(self, consumer):
+    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
         return make_deferred_yieldable(
             FileSender().beginFileTransfer(self.open_file, consumer)
         )
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 1082389d9b..a632099167 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import datetime
 import errno
 import fnmatch
@@ -23,12 +23,13 @@ import re
 import shutil
 import sys
 import traceback
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
 from urllib import parse as urlparse
 
 import attr
 
 from twisted.internet.error import DNSLookupError
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
@@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
 
 from ._base import FileInfo
 
+if TYPE_CHECKING:
+    from lxml import etree
+
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@@ -119,7 +127,12 @@ class OEmbedError(Exception):
 class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()
 
         self.auth = hs.get_auth()
@@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
                 self._start_expire_url_cache_data, 10 * 1000
             )
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         request.setHeader(b"Allow", b"OPTIONS, GET")
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
 
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
@@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
             raise OEmbedError() from e
 
-    async def _download_url(self, url: str, user):
+    async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             "expire_url_cache_data", self._expire_url_cache_data
         )
 
-    async def _expire_url_cache_data(self):
+    async def _expire_url_cache_data(self) -> None:
         """Clean up expired url cache content, media and thumbnails.
         """
         # TODO: Delete from backup media store
@@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
-def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
+def decode_and_calc_og(
+    body: bytes, media_uri: str, request_encoding: Optional[str] = None
+) -> Dict[str, Optional[str]]:
     # If there's no body, nothing useful is going to be found.
     if not body:
         return {}
@@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
     return og
 
 
-def _calc_og(tree, media_uri):
+def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
     # suck our tree into lxml and define our OG response.
 
     # if we see any image URLs in the OG response, then spider them
@@ -801,7 +816,9 @@ def _calc_og(tree, media_uri):
                 for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
             )
             og["og:description"] = summarize_paragraphs(text_nodes)
-    else:
+    elif og["og:description"]:
+        # This must be a non-empty string at this point.
+        assert isinstance(og["og:description"], str)
         og["og:description"] = summarize_paragraphs([og["og:description"]])
 
     # TODO: delete the url downloads to stop diskfilling,
@@ -809,7 +826,9 @@ def _calc_og(tree, media_uri):
     return og
 
 
-def _iterate_over_text(tree, *tags_to_ignore):
+def _iterate_over_text(
+    tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
     """Iterate over the tree returning text nodes in a depth first fashion,
     skipping text nodes inside certain tags.
     """
@@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
             )
 
 
-def _rebase_url(url, base):
-    base = list(urlparse.urlparse(base))
-    url = list(urlparse.urlparse(url))
-    if not url[0]:  # fix up schema
-        url[0] = base[0] or "http"
-    if not url[1]:  # fix up hostname
-        url[1] = base[1]
-        if not url[2].startswith("/"):
-            url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
-    return urlparse.urlunparse(url)
+def _rebase_url(url: str, base: str) -> str:
+    base_parts = list(urlparse.urlparse(base))
+    url_parts = list(urlparse.urlparse(url))
+    if not url_parts[0]:  # fix up schema
+        url_parts[0] = base_parts[0] or "http"
+    if not url_parts[1]:  # fix up hostname
+        url_parts[1] = base_parts[1]
+        if not url_parts[2].startswith("/"):
+            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+    return urlparse.urlunparse(url_parts)
 
 
-def _is_media(content_type):
-    if content_type.lower().startswith("image/"):
-        return True
+def _is_media(content_type: str) -> bool:
+    return content_type.lower().startswith("image/")
 
 
-def _is_html(content_type):
+def _is_html(content_type: str) -> bool:
     content_type = content_type.lower()
-    if content_type.startswith("text/html") or content_type.startswith(
+    return content_type.startswith("text/html") or content_type.startswith(
         "application/xhtml"
-    ):
-        return True
+    )
 
 
-def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+def summarize_paragraphs(
+    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
     # Try to get a summary of between 200 and 500 words, respecting
     # first paragraph and then word boundaries.
     # TODO: Respect sentences?
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 67f67efde7..e92006faa9 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import logging
 import os
 import shutil
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
@@ -27,13 +28,17 @@ from .media_storage import FileResponder
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
-class StorageProvider:
+
+class StorageProvider(metaclass=abc.ABCMeta):
     """A storage provider is a service that can store uploaded media and
     retrieve them.
     """
 
-    async def store_file(self, path: str, file_info: FileInfo):
+    @abc.abstractmethod
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """Store the file described by file_info. The actual contents can be
         retrieved by reading the file in file_info.upload_path.
 
@@ -42,6 +47,7 @@ class StorageProvider:
             file_info: The metadata of the file.
         """
 
+    @abc.abstractmethod
     async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """Attempt to fetch the file described by file_info and stream it
         into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
         self.store_synchronous = store_synchronous
         self.store_remote = store_remote
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "StorageProviderWrapper[%s]" % (self.backend,)
 
-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         if not file_info.server_name and not self.store_local:
             return None
 
@@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
         if self.store_synchronous:
             # store_file is supposed to return an Awaitable, but guard
             # against improper implementations.
-            return await maybe_awaitable(self.backend.store_file(path, file_info))
+            await maybe_awaitable(self.backend.store_file(path, file_info))  # type: ignore
         else:
             # TODO: Handle errors.
             async def store():
@@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
                     logger.exception("Error storing file")
 
             run_in_background(store)
-            return None
 
-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         # store_file is supposed to return an Awaitable, but guard
         # against improper implementations.
         return await maybe_awaitable(self.backend.fetch(path, file_info))
@@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
     """A storage provider that stores files in a directory on a filesystem.
 
     Args:
-        hs (HomeServer)
+        hs
         config: The config returned by `parse_config`.
     """
 
-    def __init__(self, hs, config):
+    def __init__(self, hs: "HomeServer", config: str):
         self.hs = hs
         self.cache_directory = hs.config.media_store_path
         self.base_directory = config
@@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
     def __str__(self):
         return "FileStorageProviderBackend[%s]" % (self.base_directory,)
 
-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """See StorageProvider.store_file"""
 
         primary_fname = os.path.join(self.cache_directory, path)
@@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
         if not os.path.exists(dirname):
             os.makedirs(dirname)
 
-        return await defer_to_thread(
+        await defer_to_thread(
             self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
         )
 
-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """See StorageProvider.fetch"""
 
         backup_fname = os.path.join(self.base_directory, path)
         if os.path.isfile(backup_fname):
             return FileResponder(open(backup_fname, "rb"))
 
+        return None
+
     @staticmethod
-    def parse_config(config):
+    def parse_config(config: dict) -> str:
         """Called on startup to parse config supplied. This should parse
         the config and raise if there is a problem.
 
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 30421b663a..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@
 
 
 import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from twisted.web.http import Request
 
 from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.rest.media.v1.media_storage import MediaStorage
 
 from ._base import (
     FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
     respond_with_responder,
 )
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class ThumbnailResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()
 
         self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         server_name, media_id, _ = parse_media_id(request)
         width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
     async def _respond_local_thumbnail(
-        self, request, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
@@ -86,41 +106,27 @@ class ThumbnailResource(DirectServeJsonResource):
             return
 
         thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
-        if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
-            )
-
-            file_info = FileInfo(
-                server_name=None,
-                file_id=media_id,
-                url_cache=media_info["url_cache"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
-
-            responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
-        else:
-            logger.info("Couldn't find any generated thumbnails")
-            respond_404(request)
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_id,
+            url_cache=media_info["url_cache"],
+            server_name=None,
+        )
 
     async def _select_or_generate_local_thumbnail(
         self,
-        request,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
@@ -178,14 +184,14 @@ class ThumbnailResource(DirectServeJsonResource):
 
     async def _select_or_generate_remote_thumbnail(
         self,
-        request,
-        server_name,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        server_name: str,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +245,15 @@ class ThumbnailResource(DirectServeJsonResource):
             raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _respond_remote_thumbnail(
-        self, request, server_name, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        server_name: str,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
@@ -249,97 +262,185 @@ class ThumbnailResource(DirectServeJsonResource):
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
         )
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_info["filesystem_id"],
+            url_cache=None,
+            server_name=server_name,
+        )
 
+    async def _select_and_respond_with_thumbnail(
+        self,
+        request: Request,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str] = None,
+        server_name: Optional[str] = None,
+    ) -> None:
+        """
+        Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            request: The incoming request.
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+        """
         if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
+            file_info = self._select_thumbnail(
+                desired_width,
+                desired_height,
+                desired_method,
+                desired_type,
+                thumbnail_infos,
+                file_id,
+                url_cache,
+                server_name,
             )
-            file_info = FileInfo(
-                server_name=server_name,
-                file_id=media_info["filesystem_id"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
+            if not file_info:
+                logger.info("Couldn't find a thumbnail matching the desired inputs")
+                respond_404(request)
+                return
 
             responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
+            await respond_with_responder(
+                request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+            )
         else:
             logger.info("Failed to find any generated thumbnails")
             respond_404(request)
 
     def _select_thumbnail(
         self,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-        thumbnail_infos,
-    ):
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str],
+        server_name: Optional[str],
+    ) -> Optional[FileInfo]:
+        """
+        Choose an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+
+        Returns:
+             The thumbnail which best matches the desired parameters.
+        """
+        desired_method = desired_method.lower()
+
+        # The chosen thumbnail.
+        thumbnail_info = None
+
         d_w = desired_width
         d_h = desired_height
 
-        if desired_method.lower() == "crop":
+        if desired_method == "crop":
+            # Thumbnails that match equal or larger sizes of desired width/height.
             crop_info_list = []
+            # Other thumbnails.
             crop_info_list2 = []
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "crop":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
-                if t_method == "crop":
-                    aspect_quality = abs(d_w * t_h - d_h * t_w)
-                    min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
-                    size_quality = abs((d_w - t_w) * (d_h - t_h))
-                    type_quality = desired_type != info["thumbnail_type"]
-                    length_quality = info["thumbnail_length"]
-                    if t_w >= d_w or t_h >= d_h:
-                        crop_info_list.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                aspect_quality = abs(d_w * t_h - d_h * t_w)
+                min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+                size_quality = abs((d_w - t_w) * (d_h - t_h))
+                type_quality = desired_type != info["thumbnail_type"]
+                length_quality = info["thumbnail_length"]
+                if t_w >= d_w or t_h >= d_h:
+                    crop_info_list.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
-                    else:
-                        crop_info_list2.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                    )
+                else:
+                    crop_info_list2.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
+                    )
             if crop_info_list:
-                return min(crop_info_list)[-1]
-            else:
-                return min(crop_info_list2)[-1]
-        else:
+                thumbnail_info = min(crop_info_list)[-1]
+            elif crop_info_list2:
+                thumbnail_info = min(crop_info_list2)[-1]
+        elif desired_method == "scale":
+            # Thumbnails that match equal or larger sizes of desired width/height.
             info_list = []
+            # Other thumbnails.
             info_list2 = []
+
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "scale":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
                 size_quality = abs((d_w - t_w) * (d_h - t_h))
                 type_quality = desired_type != info["thumbnail_type"]
                 length_quality = info["thumbnail_length"]
-                if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+                if t_w >= d_w or t_h >= d_h:
                     info_list.append((size_quality, type_quality, length_quality, info))
-                elif t_method == "scale":
+                else:
                     info_list2.append(
                         (size_quality, type_quality, length_quality, info)
                     )
             if info_list:
-                return min(info_list)[-1]
-            else:
-                return min(info_list2)[-1]
+                thumbnail_info = min(info_list)[-1]
+            elif info_list2:
+                thumbnail_info = min(info_list2)[-1]
+
+        if thumbnail_info:
+            return FileInfo(
+                file_id=file_id,
+                url_cache=url_cache,
+                server_name=server_name,
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
+                thumbnail_length=thumbnail_info["thumbnail_length"],
+            )
+
+        # No matching thumbnail was found.
+        return None
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 32a8e4f960..07903e4017 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
 # limitations under the License.
 import logging
 from io import BytesIO
+from typing import Tuple
 
 from PIL import Image
 
@@ -39,7 +41,7 @@ class Thumbnailer:
 
     FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
 
-    def __init__(self, input_path):
+    def __init__(self, input_path: str):
         try:
             self.image = Image.open(input_path)
         except OSError as e:
@@ -59,11 +61,11 @@ class Thumbnailer:
             # A lot of parsing errors can happen when parsing EXIF
             logger.info("Error parsing image EXIF information: %s", e)
 
-    def transpose(self):
+    def transpose(self) -> Tuple[int, int]:
         """Transpose the image using its EXIF Orientation tag
 
         Returns:
-            Tuple[int, int]: (width, height) containing the new image size in pixels.
+            A tuple containing the new image size in pixels as (width, height).
         """
         if self.transpose_method is not None:
             self.image = self.image.transpose(self.transpose_method)
@@ -73,7 +75,7 @@ class Thumbnailer:
             self.image.info["exif"] = None
         return self.image.size
 
-    def aspect(self, max_width, max_height):
+    def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
         """Calculate the largest size that preserves aspect ratio which
         fits within the given rectangle::
 
@@ -91,7 +93,7 @@ class Thumbnailer:
         else:
             return (max_height * self.width) // self.height, max_height
 
-    def _resize(self, width, height):
+    def _resize(self, width: int, height: int) -> Image:
         # 1-bit or 8-bit color palette images need converting to RGB
         # otherwise they will be scaled using nearest neighbour which
         # looks awful
@@ -99,7 +101,7 @@ class Thumbnailer:
             self.image = self.image.convert("RGB")
         return self.image.resize((width, height), Image.ANTIALIAS)
 
-    def scale(self, width, height, output_type):
+    def scale(self, width: int, height: int, output_type: str) -> BytesIO:
         """Rescales the image to the given dimensions.
 
         Returns:
@@ -108,7 +110,7 @@ class Thumbnailer:
         scaled = self._resize(width, height)
         return self._encode_image(scaled, output_type)
 
-    def crop(self, width, height, output_type):
+    def crop(self, width: int, height: int, output_type: str) -> BytesIO:
         """Rescales and crops the image to the given dimensions preserving
         aspect::
             (w_in / h_in) = (w_scaled / h_scaled)
@@ -136,7 +138,7 @@ class Thumbnailer:
             cropped = scaled_image.crop((crop_left, 0, crop_right, height))
         return self._encode_image(cropped, output_type)
 
-    def _encode_image(self, output_image, output_type):
+    def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
         output_bytes_io = BytesIO()
         fmt = self.FORMATS[output_type]
         if fmt == "JPEG":
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 42febc9afc..6da76ae994 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,18 +15,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class UploadResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
 
         self.media_repo = media_repo
@@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource):
         self.max_upload_size = hs.config.max_upload_size
         self.clock = hs.get_clock()
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_POST(self, request):
+    async def _async_render_POST(self, request: Request) -> None:
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
index e5b720bbca..9550b82998 100644
--- a/synapse/rest/synapse/client/pick_idp.py
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource):
         self._server_name = hs.hostname
 
     async def _async_render_GET(self, request: SynapseRequest) -> None:
-        client_redirect_url = parse_string(request, "redirectUrl", required=True)
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding="utf-8"
+        )
         idp = parse_string(request, "idp", required=False)
 
         # if we need to pick an IdP, do so
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index f591cc6c5c..241fe746d9 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -34,10 +34,6 @@ class WellKnownBuilder:
         self._config = hs.config
 
     def get_well_known(self):
-        # if we don't have a public_baseurl, we can't help much here.
-        if self._config.public_baseurl is None:
-            return None
-
         result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
 
         if self._config.default_identity_server:
diff --git a/synapse/server.py b/synapse/server.py
index a198b0eb46..9cdda83aa1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender
 from synapse.federation.transport.client import TransportLayerClient
 from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
 from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
+from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.acme import AcmeHandler
 from synapse.handlers.admin import AdminHandler
@@ -283,10 +284,6 @@ class HomeServer(metaclass=abc.ABCMeta):
         """
         return self._reactor
 
-    def get_ip_from_request(self, request) -> str:
-        # X-Forwarded-For is handled by our custom request type.
-        return request.getClientIP()
-
     def is_mine(self, domain_specific_string: DomainSpecificString) -> bool:
         return domain_specific_string.domain == self.hostname
 
@@ -505,7 +502,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return InitialSyncHandler(self)
 
     @cache_in_self
-    def get_profile_handler(self):
+    def get_profile_handler(self) -> ProfileHandler:
         return ProfileHandler(self)
 
     @cache_in_self
@@ -715,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_module_api(self) -> ModuleApi:
         return ModuleApi(self, self.get_auth_handler())
 
+    @cache_in_self
+    def get_account_data_handler(self) -> AccountDataHandler:
+        return AccountDataHandler(self)
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 2258d306d9..8dd01fce76 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -42,6 +42,7 @@ class ResourceLimitsServerNotices:
         self._auth = hs.get_auth()
         self._config = hs.config
         self._resouce_limited = False
+        self._account_data_handler = hs.get_account_data_handler()
         self._message_handler = hs.get_message_handler()
         self._state = hs.get_state_handler()
 
@@ -177,7 +178,7 @@ class ResourceLimitsServerNotices:
                 # tag already present, nothing to do here
                 need_to_set_tag = False
         if need_to_set_tag:
-            max_id = await self._store.add_tag_to_room(
+            max_id = await self._account_data_handler.add_tag_to_room(
                 user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
             )
             self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 100dbd5e2c..c46b2f047d 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -35,6 +35,7 @@ class ServerNoticesManager:
 
         self._store = hs.get_datastore()
         self._config = hs.config
+        self._account_data_handler = hs.get_account_data_handler()
         self._room_creation_handler = hs.get_room_creation_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._event_creation_handler = hs.get_event_creation_handler()
@@ -163,7 +164,7 @@ class ServerNoticesManager:
         )
         room_id = info["room_id"]
 
-        max_id = await self._store.add_tag_to_room(
+        max_id = await self._account_data_handler.add_tag_to_room(
             user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
         )
         self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b70ca3087b..d2ba4bd2fc 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection
 
 # python 3 does not have a maximum int value
@@ -179,6 +180,9 @@ class LoggingDatabaseConnection:
 _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
 
 
+R = TypeVar("R")
+
+
 class LoggingTransaction:
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -258,13 +262,32 @@ class LoggingTransaction:
         return self.txn.description
 
     def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+        """Similar to `executemany`, except `txn.rowcount` will not be correct
+        afterwards.
+
+        More efficient than `executemany` on PostgreSQL
+        """
+
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
-            for val in args:
-                self.execute(sql, val)
+            self.executemany(sql, args)
+
+    def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+        """Corresponds to psycopg2.extras.execute_values. Only available when
+        using postgres.
+
+        Always sets fetch=True when caling `execute_values`, so will return the
+        results.
+        """
+        assert isinstance(self.database_engine, PostgresEngine)
+        from psycopg2.extras import execute_values  # type: ignore
+
+        return self._do_execute(
+            lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+        )
 
     def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
@@ -276,7 +299,7 @@ class LoggingTransaction:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql: str, *args: Any) -> None:
+    def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -347,9 +370,6 @@ class PerformanceCounters:
         return top_n_counters
 
 
-R = TypeVar("R")
-
-
 class DatabasePool:
     """Wraps a single physical database and connection pool.
 
@@ -398,6 +418,16 @@ class DatabasePool:
                 self._check_safe_to_upsert,
             )
 
+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            engine, get_chain_id_txn, "event_auth_chain_id"
+        )
+
     def is_running(self) -> bool:
         """Is the database pool currently running
         """
@@ -863,7 +893,7 @@ class DatabasePool:
             ", ".join("?" for _ in keys[0]),
         )
 
-        txn.executemany(sql, vals)
+        txn.execute_batch(sql, vals)
 
     async def simple_upsert(
         self,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index b936f54f1e..5d0845588c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -162,9 +162,13 @@ class DataStore(
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index bff51e92b9..a277a1ef13 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,14 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
 import logging
 from typing import Dict, List, Optional, Set, Tuple
 
 from synapse.api.constants import AccountDataTypes
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
+class AccountDataWorkerStore(SQLBaseStore):
     """This is an abstract base class where subclasses must implement
     `get_max_account_data_stream_id` which can be called in the initializer.
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_account_data = (
+                self._instance_name in hs.config.worker.writers.account_data
+            )
+
+            self._account_data_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="account_data",
+                instance_name=self._instance_name,
+                tables=[
+                    ("room_account_data", "instance_name", "stream_id"),
+                    ("room_tags_revisions", "instance_name", "stream_id"),
+                    ("account_data", "instance_name", "stream_id"),
+                ],
+                sequence_name="account_data_sequence",
+                writers=hs.config.worker.writers.account_data,
+            )
+        else:
+            self._can_write_to_account_data = True
+
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                self._account_data_id_gen = StreamIdGenerator(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+            else:
+                self._account_data_id_gen = SlavedIdTracker(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
@@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
 
         super().__init__(database, db_conn, hs)
 
-    @abc.abstractmethod
-    def get_max_account_data_stream_id(self):
+    def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream ID for account data stream
 
         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._account_data_id_gen.get_current_token()
 
     @cached()
     async def get_account_data_for_user(
@@ -307,28 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
             )
         )
 
-
-class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn,
-            "account_data_max_stream_id",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self) -> int:
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            The maximum stream ID.
-        """
-        return self._account_data_id_gen.get_current_token()
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
+                self.get_account_data_for_room_and_type.invalidate(
+                    (row.user_id, row.room_id, row.data_type)
+                )
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -344,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
+        assert self._can_write_to_account_data
+
         content_json = json_encoder.encode(content)
 
         async with self._account_data_id_gen.get_next() as next_id:
@@ -362,14 +406,6 @@ class AccountDataStore(AccountDataWorkerStore):
                 lock=False,
             )
 
-            # it's theoretically possible for the above to succeed and the
-            # below to fail - in which case we might reuse a stream id on
-            # restart, and the above update might not get propagated. That
-            # doesn't sound any worse than the whole update getting lost,
-            # which is what would happen if we combined the two into one
-            # transaction.
-            await self._update_max_stream_id(next_id)
-
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
@@ -392,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
+        assert self._can_write_to_account_data
+
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "add_user_account_data",
@@ -402,18 +440,6 @@ class AccountDataStore(AccountDataWorkerStore):
                 content,
             )
 
-            # it's theoretically possible for the above to succeed and the
-            # below to fail - in which case we might reuse a stream id on
-            # restart, and the above update might not get propagated. That
-            # doesn't sound any worse than the whole update getting lost,
-            # which is what would happen if we combined the two into one
-            # transaction.
-            #
-            # Note: This is only here for backwards compat to allow admins to
-            # roll back to a previous Synapse version. Next time we update the
-            # database version we can remove this table.
-            await self._update_max_stream_id(next_id)
-
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
@@ -487,23 +513,6 @@ class AccountDataStore(AccountDataWorkerStore):
         for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
 
-    async def _update_max_stream_id(self, next_id: int) -> None:
-        """Update the max stream_id
-
-        Args:
-            next_id: The the revision to advance to.
-        """
 
-        # Note: This is only here for backwards compat to allow admins to
-        # roll back to a previous Synapse version. Next time we update the
-        # database version we can remove this table.
-
-        def _update(txn):
-            update_max_id_sql = (
-                "UPDATE account_data_max_stream_id"
-                " SET stream_id = ?"
-                " WHERE stream_id < ?"
-            )
-            txn.execute(update_max_id_sql, (next_id, next_id))
-
-        await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+class AccountDataStore(AccountDataWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index c53c836337..ea1e8fb580 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -407,6 +407,34 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             "_prune_old_user_ips", _prune_old_user_ips_txn
         )
 
+    async def get_last_client_ip_by_device(
+        self, user_id: str, device_id: Optional[str]
+    ) -> Dict[Tuple[str, str], dict]:
+        """For each device_id listed, give the user_ip it was last seen on.
+
+        The result might be slightly out of date as client IPs are inserted in batches.
+
+        Args:
+            user_id: The user to fetch devices for.
+            device_id: If None fetches all devices for the user
+
+        Returns:
+            A dictionary mapping a tuple of (user_id, device_id) to dicts, with
+            keys giving the column names from the devices table.
+        """
+
+        keyvalues = {"user_id": user_id}
+        if device_id is not None:
+            keyvalues["device_id"] = device_id
+
+        res = await self.db_pool.simple_select_list(
+            table="devices",
+            keyvalues=keyvalues,
+            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        )
+
+        return {(d["user_id"], d["device_id"]): d for d in res}
+
 
 class ClientIpStore(ClientIpWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -512,18 +540,9 @@ class ClientIpStore(ClientIpWorkerStore):
             A dictionary mapping a tuple of (user_id, device_id) to dicts, with
             keys giving the column names from the devices table.
         """
+        ret = await super().get_last_client_ip_by_device(user_id, device_id)
 
-        keyvalues = {"user_id": user_id}
-        if device_id is not None:
-            keyvalues["device_id"] = device_id
-
-        res = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues=keyvalues,
-            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
-        )
-
-        ret = {(d["user_id"], d["device_id"]): d for d in res}
+        # Update what is retrieved from the database with data which is pending insertion.
         for key in self._batch_row_update:
             uid, access_token, ip = key
             if uid == user_id:
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 58d3f71e45..31f70ac5ef 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="to_device",
                 instance_name=self._instance_name,
-                table="device_inbox",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("device_inbox", "instance_name", "stream_id")],
                 sequence_name="device_inbox_sequence",
                 writers=hs.config.worker.writers.to_device,
             )
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..659d8f245f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 DELETE FROM device_lists_outbound_last_success
                 WHERE destination = ? AND user_id = ?
             """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+            txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
 
             logger.info("Pruned %d device list outbound pokes", count)
 
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
-        txn.executemany(
+        txn.execute_batch(
             """
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4d1b92d1aa..c128889bf9 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -513,21 +514,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         for user_chunk in batch_iter(user_ids, 100):
             clause, params = make_in_list_sql_clause(
-                txn.database_engine, "k.user_id", user_chunk
-            )
-            sql = (
-                """
-                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
-                  FROM e2e_cross_signing_keys k
-                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
-                                FROM e2e_cross_signing_keys
-                               GROUP BY user_id, keytype) s
-                 USING (user_id, stream_id, keytype)
-                 WHERE
-            """
-                + clause
+                txn.database_engine, "user_id", user_chunk
             )
 
+            # Fetch the latest key for each type per user.
+            if isinstance(self.database_engine, PostgresEngine):
+                # The `DISTINCT ON` clause will pick the *first* row it
+                # encounters, so ordering by stream ID desc will ensure we get
+                # the latest key.
+                sql = """
+                    SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
+                        FROM e2e_cross_signing_keys
+                        WHERE %(clause)s
+                        ORDER BY user_id, keytype, stream_id DESC
+                """ % {
+                    "clause": clause
+                }
+            else:
+                # SQLite has special handling for bare columns when using
+                # MIN/MAX with a `GROUP BY` clause where it picks the value from
+                # a row that matches the MIN/MAX.
+                sql = """
+                    SELECT user_id, keytype, keydata, MAX(stream_id)
+                        FROM e2e_cross_signing_keys
+                        WHERE %(clause)s
+                        GROUP BY user_id, keytype
+                """ % {
+                    "clause": clause
+                }
+
             txn.execute(sql, params)
             rows = self.db_pool.cursor_to_dict(txn)
 
@@ -707,50 +722,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         """Get the current stream id from the _device_list_id_gen"""
         ...
 
-
-class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
-    async def set_e2e_device_keys(
-        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
-    ) -> bool:
-        """Stores device keys for a device. Returns whether there was a change
-        or the keys were already in the database.
-        """
-
-        def _set_e2e_device_keys_txn(txn):
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
-
-            old_key_json = self.db_pool.simple_select_one_onecol_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                retcol="key_json",
-                allow_none=True,
-            )
-
-            # In py3 we need old_key_json to match new_key_json type. The DB
-            # returns unicode while encode_canonical_json returns bytes.
-            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
-            if old_key_json == new_key_json:
-                log_kv({"Message": "Device key already stored."})
-                return False
-
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                values={"ts_added_ms": time_now, "key_json": new_key_json},
-            )
-            log_kv({"message": "Device keys stored."})
-            return True
-
-        return await self.db_pool.runInteraction(
-            "set_e2e_device_keys", _set_e2e_device_keys_txn
-        )
-
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
     ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
@@ -840,6 +811,50 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+    async def set_e2e_device_keys(
+        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+    ) -> bool:
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
+        """
+
+        def _set_e2e_device_keys_txn(txn):
+            set_tag("user_id", user_id)
+            set_tag("device_id", device_id)
+            set_tag("time_now", time_now)
+            set_tag("device_keys", device_keys)
+
+            old_key_json = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                retcol="key_json",
+                allow_none=True,
+            )
+
+            # In py3 we need old_key_json to match new_key_json type. The DB
+            # returns unicode while encode_canonical_json returns bytes.
+            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+            if old_key_json == new_key_json:
+                log_kv({"Message": "Device key already stored."})
+                return False
+
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                values={"ts_added_ms": time_now, "key_json": new_key_json},
+            )
+            log_kv({"message": "Device keys stored."})
+            return True
+
+        return await self.db_pool.runInteraction(
+            "set_e2e_device_keys", _set_e2e_device_keys_txn
+        )
+
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn):
             log_kv(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ebffd89251..8326640d20 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Cursor
 from synapse.types import Collection
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
 logger = logging.getLogger(__name__)
 
 
+class _NoChainCoverIndex(Exception):
+    def __init__(self, room_id: str):
+        super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
+
+
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             The set of the difference in auth chains.
         """
 
+        # Check if we have indexed the room so we can use the chain cover
+        # algorithm.
+        room = await self.get_room(room_id)
+        if room["has_auth_chain_index"]:
+            try:
+                return await self.db_pool.runInteraction(
+                    "get_auth_chain_difference_chains",
+                    self._get_auth_chain_difference_using_cover_index_txn,
+                    room_id,
+                    state_sets,
+                )
+            except _NoChainCoverIndex:
+                # For whatever reason we don't actually have a chain cover index
+                # for the events in question, so we fall back to the old method.
+                pass
+
         return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
             self._get_auth_chain_difference_txn,
             state_sets,
         )
 
+    def _get_auth_chain_difference_using_cover_index_txn(
+        self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+    ) -> Set[str]:
+        """Calculates the auth chain difference using the chain index.
+
+        See docs/auth_chain_difference_algorithm.md for details
+        """
+
+        # First we look up the chain ID/sequence numbers for all the events, and
+        # work out the chain/sequence numbers reachable from each state set.
+
+        initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+        # Map from event_id -> (chain ID, seq no)
+        chain_info = {}  # type: Dict[str, Tuple[int, int]]
+
+        # Map from chain ID -> seq no -> event Id
+        chain_to_event = {}  # type: Dict[int, Dict[int, str]]
+
+        # All the chains that we've found that are reachable from the state
+        # sets.
+        seen_chains = set()  # type: Set[int]
+
+        sql = """
+            SELECT event_id, chain_id, sequence_number
+            FROM event_auth_chains
+            WHERE %s
+        """
+        for batch in batch_iter(initial_events, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for event_id, chain_id, sequence_number in txn:
+                chain_info[event_id] = (chain_id, sequence_number)
+                seen_chains.add(chain_id)
+                chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+        # Check that we actually have a chain ID for all the events.
+        events_missing_chain_info = initial_events.difference(chain_info)
+        if events_missing_chain_info:
+            # This can happen due to e.g. downgrade/upgrade of the server. We
+            # raise an exception and fall back to the previous algorithm.
+            logger.info(
+                "Unexpectedly found that events don't have chain IDs in room %s: %s",
+                room_id,
+                events_missing_chain_info,
+            )
+            raise _NoChainCoverIndex(room_id)
+
+        # Corresponds to `state_sets`, except as a map from chain ID to max
+        # sequence number reachable from the state set.
+        set_to_chain = []  # type: List[Dict[int, int]]
+        for state_set in state_sets:
+            chains = {}  # type: Dict[int, int]
+            set_to_chain.append(chains)
+
+            for event_id in state_set:
+                chain_id, seq_no = chain_info[event_id]
+
+                chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
+
+        # Now we look up all links for the chains we have, adding chains to
+        # set_to_chain that are reachable from each set.
+        sql = """
+            SELECT
+                origin_chain_id, origin_sequence_number,
+                target_chain_id, target_sequence_number
+            FROM event_auth_chain_links
+            WHERE %s
+        """
+
+        # (We need to take a copy of `seen_chains` as we want to mutate it in
+        # the loop)
+        for batch in batch_iter(set(seen_chains), 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "origin_chain_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in txn:
+                for chains in set_to_chain:
+                    # chains are only reachable if the origin sequence number of
+                    # the link is less than the max sequence number in the
+                    # origin chain.
+                    if origin_sequence_number <= chains.get(origin_chain_id, 0):
+                        chains[target_chain_id] = max(
+                            target_sequence_number, chains.get(target_chain_id, 0),
+                        )
+
+                seen_chains.add(target_chain_id)
+
+        # Now for each chain we figure out the maximum sequence number reachable
+        # from *any* state set and the minimum sequence number reachable from
+        # *all* state sets. Events in that range are in the auth chain
+        # difference.
+        result = set()
+
+        # Mapping from chain ID to the range of sequence numbers that should be
+        # pulled from the database.
+        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]
+
+        for chain_id in seen_chains:
+            min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
+            max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
+
+            if min_seq_no < max_seq_no:
+                # We have a non empty gap, try and fill it from the events that
+                # we have, otherwise add them to the list of gaps to pull out
+                # from the DB.
+                for seq_no in range(min_seq_no + 1, max_seq_no + 1):
+                    event_id = chain_to_event.get(chain_id, {}).get(seq_no)
+                    if event_id:
+                        result.add(event_id)
+                    else:
+                        chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
+                        break
+
+        if not chain_to_gap:
+            # If there are no gaps to fetch, we're done!
+            return result
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # We can use `execute_values` to efficiently fetch the gaps when
+            # using postgres.
+            sql = """
+                SELECT event_id
+                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
+                WHERE
+                    c.chain_id = l.chain_id
+                    AND min_seq < sequence_number AND sequence_number <= max_seq
+            """
+
+            args = [
+                (chain_id, min_no, max_no)
+                for chain_id, (min_no, max_no) in chain_to_gap.items()
+            ]
+
+            rows = txn.execute_values(sql, args)
+            result.update(r for r, in rows)
+        else:
+            # For SQLite we just fall back to doing a noddy for loop.
+            sql = """
+                SELECT event_id FROM event_auth_chains
+                WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
+            """
+            for chain_id, (min_no, max_no) in chain_to_gap.items():
+                txn.execute(sql, (chain_id, min_no, max_no))
+                result.update(r for r, in txn)
+
+        return result
+
     def _get_auth_chain_difference_txn(
         self, txn, state_sets: List[Set[str]]
     ) -> Set[str]:
+        """Calculates the auth chain difference using a breadth first search.
+
+        This is used when we don't have a cover index for the room.
+        """
 
         # Algorithm Description
         # ~~~~~~~~~~~~~~~~~~~~~
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index e5c03cc609..438383abe1 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 VALUES (?, ?, ?, ?, ?, ?)
             """
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     _gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             ],
         )
 
-        txn.executemany(
+        txn.execute_batch(
             """
                 UPDATE event_push_summary
                 SET notif_count = ?, unread_count = ?, stream_ordering = ?
@@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             (rotate_to_stream_ordering,),
         )
 
+    def _remove_old_push_actions_before_txn(
+        self, txn, room_id, user_id, stream_ordering
+    ):
+        """
+        Purges old push actions for a user and room before a given
+        stream_ordering.
+
+        We however keep a months worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transcation
+            room_id: Room ID to delete from
+            user_id: user ID to delete for
+            stream_ordering: The lowest stream ordering which will
+                                  not be deleted.
+        """
+        txn.call_after(
+            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            (room_id, user_id),
+        )
+
+        # We need to join on the events table to get the received_ts for
+        # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore we can
+        # only identify event_push_actions by a tuple of room_id, event_id
+        # we we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_odering before that.
+        txn.execute(
+            "DELETE FROM event_push_actions "
+            " WHERE user_id = ? AND room_id = ? AND "
+            " stream_ordering <= ?"
+            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+        )
+
+        txn.execute(
+            """
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """,
+            (room_id, user_id, stream_ordering),
+        )
+
 
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions
 
-    def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
-
-        Args:
-            txn: The transcation
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                                  not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id, user_id),
-        )
-
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
-        )
-
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
-        )
-
 
 def _action_has_highlight(actions):
     for action in actions:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 90fb1a1f00..ccda9f1caa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,17 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
@@ -35,7 +45,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
-from synapse.util.iterutils import batch_iter
+from synapse.util.iterutils import batch_iter, sorted_topologically
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -366,6 +376,36 @@ class PersistEventsStore:
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
+        self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn, events_and_contexts=events_and_contexts
+        )
+
+        # From this point onwards the events are only ones that weren't
+        # rejected.
+
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+            backfilled=backfilled,
+        )
+
+        # We call this last as it assumes we've inserted the events into
+        # room_memberships, where applicable.
+        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+    def _persist_event_auth_chain_txn(
+        self, txn: LoggingTransaction, events: List[EventBase],
+    ) -> None:
+
+        # We only care about state events, so this if there are no state events.
+        if not any(e.is_state() for e in events):
+            return
+
         # We want to store event_auth mappings for rejected events, as they're
         # used in state res v2.
         # This is only necessary if the rejected event appears in an accepted
@@ -381,31 +421,467 @@ class PersistEventsStore:
                     "room_id": event.room_id,
                     "auth_id": auth_id,
                 }
-                for event, _ in events_and_contexts
+                for event in events
                 for auth_id in event.auth_event_ids()
                 if event.is_state()
             ],
         )
 
-        # _store_rejected_events_txn filters out any events which were
-        # rejected, and returns the filtered list.
-        events_and_contexts = self._store_rejected_events_txn(
-            txn, events_and_contexts=events_and_contexts
+        # We now calculate chain ID/sequence numbers for any state events we're
+        # persisting. We ignore out of band memberships as we're not in the room
+        # and won't have their auth chain (we'll fix it up later if we join the
+        # room).
+        #
+        # See: docs/auth_chain_difference_algorithm.md
+
+        # We ignore legacy rooms that we aren't filling the chain cover index
+        # for.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="rooms",
+            column="room_id",
+            iterable={event.room_id for event in events if event.is_state()},
+            keyvalues={},
+            retcols=("room_id", "has_auth_chain_index"),
         )
+        rooms_using_chain_index = {
+            row["room_id"] for row in rows if row["has_auth_chain_index"]
+        }
 
-        # From this point onwards the events are only ones that weren't
-        # rejected.
+        state_events = {
+            event.event_id: event
+            for event in events
+            if event.is_state() and event.room_id in rooms_using_chain_index
+        }
 
-        self._update_metadata_tables_txn(
+        if not state_events:
+            return
+
+        # We need to know the type/state_key and auth events of the events we're
+        # calculating chain IDs for. We don't rely on having the full Event
+        # instances as we'll potentially be pulling more events from the DB and
+        # we don't need the overhead of fetching/parsing the full event JSON.
+        event_to_types = {
+            e.event_id: (e.type, e.state_key) for e in state_events.values()
+        }
+        event_to_auth_chain = {
+            e.event_id: e.auth_event_ids() for e in state_events.values()
+        }
+        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+        self._add_chain_cover_index(
+            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+        )
+
+    @classmethod
+    def _add_chain_cover_index(
+        cls,
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+    ) -> None:
+        """Calculate the chain cover index for the given events.
+
+        Args:
+            event_to_room_id: Event ID to the room ID of the event
+            event_to_types: Event ID to type and state_key of the event
+            event_to_auth_chain: Event ID to list of auth event IDs of the
+                event (events with no auth events can be excluded).
+        """
+
+        # Map from event ID to chain ID/sequence number.
+        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+
+        # Set of event IDs to calculate chain ID/seq numbers for.
+        events_to_calc_chain_id_for = set(event_to_room_id)
+
+        # We check if there are any events that need to be handled in the rooms
+        # we're looking at. These should just be out of band memberships, where
+        # we didn't have the auth chain when we first persisted.
+        rows = db_pool.simple_select_many_txn(
             txn,
-            events_and_contexts=events_and_contexts,
-            all_events_and_contexts=all_events_and_contexts,
-            backfilled=backfilled,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="room_id",
+            iterable=set(event_to_room_id.values()),
+            retcols=("event_id", "type", "state_key"),
         )
+        for row in rows:
+            event_id = row["event_id"]
+            event_type = row["type"]
+            state_key = row["state_key"]
+
+            # (We could pull out the auth events for all rows at once using
+            # simple_select_many, but this case happens rarely and almost always
+            # with a single row.)
+            auth_events = db_pool.simple_select_onecol_txn(
+                txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
+            )
 
-        # We call this last as it assumes we've inserted the events into
-        # room_memberships, where applicable.
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+            events_to_calc_chain_id_for.add(event_id)
+            event_to_types[event_id] = (event_type, state_key)
+            event_to_auth_chain[event_id] = auth_events
+
+        # First we get the chain ID and sequence numbers for the events'
+        # auth events (that aren't also currently being persisted).
+        #
+        # Note that there there is an edge case here where we might not have
+        # calculated chains and sequence numbers for events that were "out
+        # of band". We handle this case by fetching the necessary info and
+        # adding it to the set of events to calculate chain IDs for.
+
+        missing_auth_chains = {
+            a_id
+            for auth_events in event_to_auth_chain.values()
+            for a_id in auth_events
+            if a_id not in events_to_calc_chain_id_for
+        }
+
+        # We loop here in case we find an out of band membership and need to
+        # fetch their auth event info.
+        while missing_auth_chains:
+            sql = """
+                SELECT event_id, events.type, state_key, chain_id, sequence_number
+                FROM events
+                INNER JOIN state_events USING (event_id)
+                LEFT JOIN event_auth_chains USING (event_id)
+                WHERE
+            """
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", missing_auth_chains,
+            )
+            txn.execute(sql + clause, args)
+
+            missing_auth_chains.clear()
+
+            for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+                event_to_types[auth_id] = (event_type, state_key)
+
+                if chain_id is None:
+                    # No chain ID, so the event was persisted out of band.
+                    # We add to list of events to calculate auth chains for.
+
+                    events_to_calc_chain_id_for.add(auth_id)
+
+                    event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
+                        txn,
+                        "event_auth",
+                        keyvalues={"event_id": auth_id},
+                        retcol="auth_id",
+                    )
+
+                    missing_auth_chains.update(
+                        e
+                        for e in event_to_auth_chain[auth_id]
+                        if e not in event_to_types
+                    )
+                else:
+                    chain_map[auth_id] = (chain_id, sequence_number)
+
+        # Now we check if we have any events where we don't have auth chain,
+        # this should only be out of band memberships.
+        for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
+            for auth_id in event_to_auth_chain[event_id]:
+                if (
+                    auth_id not in chain_map
+                    and auth_id not in events_to_calc_chain_id_for
+                ):
+                    events_to_calc_chain_id_for.discard(event_id)
+
+                    # If this is an event we're trying to persist we add it to
+                    # the list of events to calculate chain IDs for next time
+                    # around. (Otherwise we will have already added it to the
+                    # table).
+                    room_id = event_to_room_id.get(event_id)
+                    if room_id:
+                        e_type, state_key = event_to_types[event_id]
+                        db_pool.simple_insert_txn(
+                            txn,
+                            table="event_auth_chain_to_calculate",
+                            values={
+                                "event_id": event_id,
+                                "room_id": room_id,
+                                "type": e_type,
+                                "state_key": state_key,
+                            },
+                        )
+
+                    # We stop checking the event's auth events since we've
+                    # discarded it.
+                    break
+
+        if not events_to_calc_chain_id_for:
+            return
+
+        # Allocate chain ID/sequence numbers to each new event.
+        new_chain_tuples = cls._allocate_chain_ids(
+            txn,
+            db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
+            events_to_calc_chain_id_for,
+            chain_map,
+        )
+        chain_map.update(new_chain_tuples)
+
+        db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chains",
+            values=[
+                {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+                for event_id, (c_id, seq) in new_chain_tuples.items()
+            ],
+        )
+
+        db_pool.simple_delete_many_txn(
+            txn,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="event_id",
+            iterable=new_chain_tuples,
+        )
+
+        # Now we need to calculate any new links between chains caused by
+        # the new events.
+        #
+        # Links are pairs of chain ID/sequence numbers such that for any
+        # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
+        # if and only if there is at least one link (CA, S1) -> (CB, S2)
+        # where SA >= S1 and S2 >= SB.
+        #
+        # We try and avoid adding redundant links to the table, e.g. if we
+        # have two links between two chains which both start/end at the
+        # sequence number event (or cross) then one can be safely dropped.
+        #
+        # To calculate new links we look at every new event and:
+        #   1. Fetch the chain ID/sequence numbers of its auth events,
+        #      discarding any that are reachable by other auth events, or
+        #      that have the same chain ID as the event.
+        #   2. For each retained auth event we:
+        #       a. Add a link from the event's to the auth event's chain
+        #          ID/sequence number; and
+        #       b. Add a link from the event to every chain reachable by the
+        #          auth event.
+
+        # Step 1, fetch all existing links from all the chains we've seen
+        # referenced.
+        chain_links = _LinkMap()
+        rows = db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth_chain_links",
+            column="origin_chain_id",
+            iterable={chain_id for chain_id, _ in chain_map.values()},
+            keyvalues={},
+            retcols=(
+                "origin_chain_id",
+                "origin_sequence_number",
+                "target_chain_id",
+                "target_sequence_number",
+            ),
+        )
+        for row in rows:
+            chain_links.add_link(
+                (row["origin_chain_id"], row["origin_sequence_number"]),
+                (row["target_chain_id"], row["target_sequence_number"]),
+                new=False,
+            )
+
+        # We do this in toplogical order to avoid adding redundant links.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            chain_id, sequence_number = chain_map[event_id]
+
+            # Filter out auth events that are reachable by other auth
+            # events. We do this by looking at every permutation of pairs of
+            # auth events (A, B) to check if B is reachable from A.
+            reduction = {
+                a_id
+                for a_id in event_to_auth_chain.get(event_id, [])
+                if chain_map[a_id][0] != chain_id
+            }
+            for start_auth_id, end_auth_id in itertools.permutations(
+                event_to_auth_chain.get(event_id, []), r=2,
+            ):
+                if chain_links.exists_path_from(
+                    chain_map[start_auth_id], chain_map[end_auth_id]
+                ):
+                    reduction.discard(end_auth_id)
+
+            # Step 2, figure out what the new links are from the reduced
+            # list of auth events.
+            for auth_id in reduction:
+                auth_chain_id, auth_sequence_number = chain_map[auth_id]
+
+                # Step 2a, add link between the event and auth event
+                chain_links.add_link(
+                    (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
+                )
+
+                # Step 2b, add a link to chains reachable from the auth
+                # event.
+                for target_id, target_seq in chain_links.get_links_from(
+                    (auth_chain_id, auth_sequence_number)
+                ):
+                    if target_id == chain_id:
+                        continue
+
+                    chain_links.add_link(
+                        (chain_id, sequence_number), (target_id, target_seq)
+                    )
+
+        db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chain_links",
+            values=[
+                {
+                    "origin_chain_id": source_id,
+                    "origin_sequence_number": source_seq,
+                    "target_chain_id": target_id,
+                    "target_sequence_number": target_seq,
+                }
+                for (
+                    source_id,
+                    source_seq,
+                    target_id,
+                    target_seq,
+                ) in chain_links.get_additions()
+            ],
+        )
+
+    @staticmethod
+    def _allocate_chain_ids(
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+        events_to_calc_chain_id_for: Set[str],
+        chain_map: Dict[str, Tuple[int, int]],
+    ) -> Dict[str, Tuple[int, int]]:
+        """Allocates, but does not persist, chain ID/sequence numbers for the
+        events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+        for info on args)
+        """
+
+        # We now calculate the chain IDs/sequence numbers for the events. We do
+        # this by looking at the chain ID and sequence number of any auth event
+        # with the same type/state_key and incrementing the sequence number by
+        # one. If there was no match or the chain ID/sequence number is already
+        # taken we generate a new chain.
+        #
+        # We try to reduce the number of times that we hit the database by
+        # batching up calls, to make this more efficient when persisting large
+        # numbers of state events (e.g. during joins).
+        #
+        # We do this by:
+        #   1. Calculating for each event which auth event will be used to
+        #      inherit the chain ID, i.e. converting the auth chain graph to a
+        #      tree that we can allocate chains on. We also keep track of which
+        #      existing chain IDs have been referenced.
+        #   2. Fetching the max allocated sequence number for each referenced
+        #      existing chain ID, generating a map from chain ID to the max
+        #      allocated sequence number.
+        #   3. Iterating over the tree and allocating a chain ID/seq no. to the
+        #      new event, by incrementing the sequence number from the
+        #      referenced event's chain ID/seq no. and checking that the
+        #      incremented sequence number hasn't already been allocated (by
+        #      looking in the map generated in the previous step). We generate a
+        #      new chain if the sequence number has already been allocated.
+        #
+
+        existing_chains = set()  # type: Set[int]
+        tree = []  # type: List[Tuple[str, Optional[str]]]
+
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events before
+        # the event itself.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            for auth_id in event_to_auth_chain.get(event_id, []):
+                if event_to_types.get(event_id) == event_to_types.get(auth_id):
+                    existing_chain_id = chain_map.get(auth_id)
+                    if existing_chain_id:
+                        existing_chains.add(existing_chain_id[0])
+
+                    tree.append((event_id, auth_id))
+                    break
+            else:
+                tree.append((event_id, None))
+
+        # Fetch the current max sequence number for each existing referenced chain.
+        sql = """
+            SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+            WHERE %s
+            GROUP BY chain_id
+        """
+        clause, args = make_in_list_sql_clause(
+            db_pool.engine, "chain_id", existing_chains
+        )
+        txn.execute(sql % (clause,), args)
+
+        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+
+        # Allocate the new events chain ID/sequence numbers.
+        #
+        # To reduce the number of calls to the database we don't allocate a
+        # chain ID number in the loop, instead we use a temporary `object()` for
+        # each new chain ID. Once we've done the loop we generate the necessary
+        # number of new chain IDs in one call, replacing all temporary
+        # objects with real allocated chain IDs.
+
+        unallocated_chain_ids = set()  # type: Set[object]
+        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        for event_id, auth_event_id in tree:
+            # If we reference an auth_event_id we fetch the allocated chain ID,
+            # either from the existing `chain_map` or the newly generated
+            # `new_chain_tuples` map.
+            existing_chain_id = None
+            if auth_event_id:
+                existing_chain_id = new_chain_tuples.get(auth_event_id)
+                if not existing_chain_id:
+                    existing_chain_id = chain_map[auth_event_id]
+
+            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check its
+                # not already taken.
+                proposed_new_id = existing_chain_id[0]
+                proposed_new_seq = existing_chain_id[1] + 1
+
+                if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+                    new_chain_tuple = (
+                        proposed_new_id,
+                        proposed_new_seq,
+                    )
+
+            # If we need to start a new chain we allocate a temporary chain ID.
+            if not new_chain_tuple:
+                new_chain_tuple = (object(), 1)
+                unallocated_chain_ids.add(new_chain_tuple[0])
+
+            new_chain_tuples[event_id] = new_chain_tuple
+            chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+        # Generate new chain IDs for all unallocated chain IDs.
+        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+            txn, len(unallocated_chain_ids)
+        )
+
+        # Map from potentially temporary chain ID to real chain ID
+        chain_id_to_allocated_map = dict(
+            zip(unallocated_chain_ids, newly_allocated_chain_ids)
+        )  # type: Dict[Any, int]
+        chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+        return {
+            event_id: (chain_id_to_allocated_map[chain_id], seq)
+            for event_id, (chain_id, seq) in new_chain_tuples.items()
+        }
 
     def _persist_transaction_ids_txn(
         self,
@@ -489,7 +965,7 @@ class PersistEventsStore:
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
                 """
-                txn.executemany(
+                txn.execute_batch(
                     sql,
                     (
                         (
@@ -508,7 +984,7 @@ class PersistEventsStore:
                 )
                 # Now we actually update the current_state_events table
 
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM current_state_events"
                     " WHERE room_id = ? AND type = ? AND state_key = ?",
                     (
@@ -520,7 +996,7 @@ class PersistEventsStore:
                 # We include the membership in the current state table, hence we do
                 # a lookup when we insert. This assumes that all events have already
                 # been inserted into room_memberships.
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO current_state_events
                         (room_id, type, state_key, event_id, membership)
                     VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -540,7 +1016,7 @@ class PersistEventsStore:
             # we have no record of the fact the user *was* a member of the
             # room but got, say, state reset out of it.
             if to_delete or to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM local_current_membership"
                     " WHERE room_id = ? AND user_id = ?",
                     (
@@ -551,7 +1027,7 @@ class PersistEventsStore:
                 )
 
             if to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO local_current_membership
                         (room_id, user_id, event_id, membership)
                     VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -799,7 +1275,8 @@ class PersistEventsStore:
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     def _store_event_txn(self, txn, events_and_contexts):
-        """Insert new events into the event and event_json tables
+        """Insert new events into the event, event_json, redaction and
+        state_events tables.
 
         Args:
             txn (twisted.enterprise.adbapi.Connection): db connection
@@ -871,6 +1348,29 @@ class PersistEventsStore:
                     updatevalues={"have_censored": False},
                 )
 
+        state_events_and_contexts = [
+            ec for ec in events_and_contexts if ec[0].is_state()
+        ]
+
+        state_values = []
+        for event, context in state_events_and_contexts:
+            vals = {
+                "event_id": event.event_id,
+                "room_id": event.room_id,
+                "type": event.type,
+                "state_key": event.state_key,
+            }
+
+            # TODO: How does this work with backfilling?
+            if hasattr(event, "replaces_state"):
+                vals["prev_state"] = event.replaces_state
+
+            state_values.append(vals)
+
+        self.db_pool.simple_insert_many_txn(
+            txn, table="state_events", values=state_values
+        )
+
     def _store_rejected_events_txn(self, txn, events_and_contexts):
         """Add rows to the 'rejections' table for received events which were
         rejected
@@ -987,29 +1487,6 @@ class PersistEventsStore:
             txn, [event for event, _ in events_and_contexts]
         )
 
-        state_events_and_contexts = [
-            ec for ec in events_and_contexts if ec[0].is_state()
-        ]
-
-        state_values = []
-        for event, context in state_events_and_contexts:
-            vals = {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "type": event.type,
-                "state_key": event.state_key,
-            }
-
-            # TODO: How does this work with backfilling?
-            if hasattr(event, "replaces_state"):
-                vals["prev_state"] = event.replaces_state
-
-            state_values.append(vals)
-
-        self.db_pool.simple_insert_many_txn(
-            txn, table="state_events", values=state_values
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
@@ -1350,7 +1827,7 @@ class PersistEventsStore:
         """
 
         if events_and_contexts:
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (
@@ -1379,7 +1856,7 @@ class PersistEventsStore:
 
         # Now we delete the staging area for *all* events that were being
         # persisted.
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM event_push_actions_staging WHERE event_id = ?",
             ((event.event_id,) for event, _ in all_events_and_contexts),
         )
@@ -1498,7 +1975,7 @@ class PersistEventsStore:
             " )"
         )
 
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1512,7 +1989,7 @@ class PersistEventsStore:
             "DELETE FROM event_backward_extremities"
             " WHERE event_id = ? AND room_id = ?"
         )
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (ev.event_id, ev.room_id)
@@ -1520,3 +1997,131 @@ class PersistEventsStore:
                 if not ev.internal_metadata.is_outlier()
             ],
         )
+
+
+@attr.s(slots=True)
+class _LinkMap:
+    """A helper type for tracking links between chains.
+    """
+
+    # Stores the set of links as nested maps: source chain ID -> target chain ID
+    # -> source sequence number -> target sequence number.
+    maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+
+    # Stores the links that have been added (with new set to true), as tuples of
+    # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
+    additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+
+    def add_link(
+        self,
+        src_tuple: Tuple[int, int],
+        target_tuple: Tuple[int, int],
+        new: bool = True,
+    ) -> bool:
+        """Add a new link between two chains, ensuring no redundant links are added.
+
+        New links should be added in topological order.
+
+        Args:
+            src_tuple: The chain ID/sequence number of the source of the link.
+            target_tuple: The chain ID/sequence number of the target of the link.
+            new: Whether this is a "new" link, i.e. should it be returned
+                by `get_additions`.
+
+        Returns:
+            True if a link was added, false if the given link was dropped as redundant
+        """
+        src_chain, src_seq = src_tuple
+        target_chain, target_seq = target_tuple
+
+        current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
+
+        assert src_chain != target_chain
+
+        if new:
+            # Check if the new link is redundant
+            for current_seq_src, current_seq_target in current_links.items():
+                # If a link "crosses" another link then its redundant. For example
+                # in the following link 1 (L1) is redundant, as any event reachable
+                # via L1 is *also* reachable via L2.
+                #
+                #   Chain A     Chain B
+                #      |          |
+                #   L1 |------    |
+                #      |     |    |
+                #   L2 |---- | -->|
+                #      |     |    |
+                #      |     |--->|
+                #      |          |
+                #      |          |
+                #
+                # So we only need to keep links which *do not* cross, i.e. links
+                # that both start and end above or below an existing link.
+                #
+                # Note, since we add links in topological ordering we should never
+                # see `src_seq` less than `current_seq_src`.
+
+                if current_seq_src <= src_seq and target_seq <= current_seq_target:
+                    # This new link is redundant, nothing to do.
+                    return False
+
+            self.additions.add((src_chain, src_seq, target_chain, target_seq))
+
+        current_links[src_seq] = target_seq
+        return True
+
+    def get_links_from(
+        self, src_tuple: Tuple[int, int]
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Gets the chains reachable from the given chain/sequence number.
+
+        Yields:
+            The chain ID and sequence number the link points to.
+        """
+        src_chain, src_seq = src_tuple
+        for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
+            for link_src_seq, target_seq in sequence_numbers.items():
+                if link_src_seq <= src_seq:
+                    yield target_id, target_seq
+
+    def get_links_between(
+        self, source_chain: int, target_chain: int
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Gets the links between two chains.
+
+        Yields:
+            The source and target sequence numbers.
+        """
+
+        yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
+
+    def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
+        """Gets any newly added links.
+
+        Yields:
+            The source chain ID/sequence number and target chain ID/sequence number
+        """
+
+        for src_chain, src_seq, target_chain, _ in self.additions:
+            target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
+            if target_seq is not None:
+                yield (src_chain, src_seq, target_chain, target_seq)
+
+    def exists_path_from(
+        self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
+    ) -> bool:
+        """Checks if there is a path between the source chain ID/sequence and
+        target chain ID/sequence.
+        """
+        src_chain, src_seq = src_tuple
+        target_chain, target_seq = target_tuple
+
+        if src_chain == target_chain:
+            return target_seq <= src_seq
+
+        links = self.get_links_between(src_chain, target_chain)
+        for link_start_seq, link_end_seq in links:
+            if link_start_seq <= src_seq and target_seq <= link_end_seq:
+                return True
+
+        return False
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 97b6754846..5ca4fa6817 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -14,14 +14,41 @@
 # limitations under the License.
 
 import logging
+from typing import Dict, List, Optional, Tuple
+
+import attr
 
 from synapse.api.constants import EventContentFields
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.databases.main.events import PersistEventsStore
+from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True)
+class _CalculateChainCover:
+    """Return value for _calculate_chain_cover_txn.
+    """
+
+    # The last room_id/depth/stream processed.
+    room_id = attr.ib(type=str)
+    depth = attr.ib(type=int)
+    stream = attr.ib(type=int)
+
+    # Number of rows processed
+    processed_count = attr.ib(type=int)
+
+    # Map from room_id to last depth/stream processed for each room that we have
+    # processed all events for (i.e. the rooms we can flip the
+    # `has_auth_chain_index` for)
+    finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+
+
 class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@@ -99,13 +126,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             columns=["user_id", "created_ts"],
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "rejected_events_metadata", self._rejected_events_metadata,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            "chain_cover", self._chain_cover_index,
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
@@ -143,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
 
-            for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
-                clump = update_rows[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, update_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -175,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_search_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
@@ -221,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
 
-            for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
-                clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, rows_to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -582,3 +609,314 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             await self.db_pool.updates._end_background_update("event_store_labels")
 
         return num_rows
+
+    async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int:
+        """Adds rejected events to the `state_events` and `event_auth` metadata
+        tables.
+        """
+
+        last_event_id = progress.get("last_event_id", "")
+
+        def get_rejected_events(
+            txn: Cursor,
+        ) -> List[Tuple[str, str, JsonDict, bool, bool]]:
+            # Fetch rejected event json, their room version and whether we have
+            # inserted them into the state_events or auth_events tables.
+            #
+            # Note we can assume that events that don't have a corresponding
+            # room version are V1 rooms.
+            sql = """
+                SELECT DISTINCT
+                    event_id,
+                    COALESCE(room_version, '1'),
+                    json,
+                    state_events.event_id IS NOT NULL,
+                    event_auth.event_id IS NOT NULL
+                FROM rejections
+                INNER JOIN event_json USING (event_id)
+                LEFT JOIN rooms USING (room_id)
+                LEFT JOIN state_events USING (event_id)
+                LEFT JOIN event_auth USING (event_id)
+                WHERE event_id > ?
+                ORDER BY event_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_event_id, batch_size,))
+
+            return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
+
+        results = await self.db_pool.runInteraction(
+            desc="_rejected_events_metadata_get", func=get_rejected_events
+        )
+
+        if not results:
+            await self.db_pool.updates._end_background_update(
+                "rejected_events_metadata"
+            )
+            return 0
+
+        state_events = []
+        auth_events = []
+        for event_id, room_version, event_json, has_state, has_event_auth in results:
+            last_event_id = event_id
+
+            if has_state and has_event_auth:
+                continue
+
+            room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version)
+            if not room_version_obj:
+                # We no longer support this room version, so we just ignore the
+                # events entirely.
+                logger.info(
+                    "Ignoring event with unknown room version %r: %r",
+                    room_version,
+                    event_id,
+                )
+                continue
+
+            event = make_event_from_dict(event_json, room_version_obj)
+
+            if not event.is_state():
+                continue
+
+            if not has_state:
+                state_events.append(
+                    {
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "type": event.type,
+                        "state_key": event.state_key,
+                    }
+                )
+
+            if not has_event_auth:
+                for auth_id in event.auth_event_ids():
+                    auth_events.append(
+                        {
+                            "room_id": event.room_id,
+                            "event_id": event.event_id,
+                            "auth_id": auth_id,
+                        }
+                    )
+
+        if state_events:
+            await self.db_pool.simple_insert_many(
+                table="state_events",
+                values=state_events,
+                desc="_rejected_events_metadata_state_events",
+            )
+
+        if auth_events:
+            await self.db_pool.simple_insert_many(
+                table="event_auth",
+                values=auth_events,
+                desc="_rejected_events_metadata_event_auth",
+            )
+
+        await self.db_pool.updates._background_update_progress(
+            "rejected_events_metadata", {"last_event_id": last_event_id}
+        )
+
+        if len(results) < batch_size:
+            await self.db_pool.updates._end_background_update(
+                "rejected_events_metadata"
+            )
+
+        return len(results)
+
+    async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
+        """A background updates that iterates over all rooms and generates the
+        chain cover index for them.
+        """
+
+        current_room_id = progress.get("current_room_id", "")
+
+        # Where we've processed up to in the room, defaults to the start of the
+        # room.
+        last_depth = progress.get("last_depth", -1)
+        last_stream = progress.get("last_stream", -1)
+
+        result = await self.db_pool.runInteraction(
+            "_chain_cover_index",
+            self._calculate_chain_cover_txn,
+            current_room_id,
+            last_depth,
+            last_stream,
+            batch_size,
+            single_room=False,
+        )
+
+        finished = result.processed_count == 0
+
+        total_rows_processed = result.processed_count
+        current_room_id = result.room_id
+        last_depth = result.depth
+        last_stream = result.stream
+
+        for room_id, (depth, stream) in result.finished_room_map.items():
+            # If we've done all the events in the room we flip the
+            # `has_auth_chain_index` in the DB. Note that its possible for
+            # further events to be persisted between the above and setting the
+            # flag without having the chain cover calculated for them. This is
+            # fine as a) the code gracefully handles these cases and b) we'll
+            # calculate them below.
+
+            await self.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": True},
+                desc="_chain_cover_index",
+            )
+
+            # Handle any events that might have raced with us flipping the
+            # bit above.
+            result = await self.db_pool.runInteraction(
+                "_chain_cover_index",
+                self._calculate_chain_cover_txn,
+                room_id,
+                depth,
+                stream,
+                batch_size=None,
+                single_room=True,
+            )
+
+            total_rows_processed += result.processed_count
+
+        if finished:
+            await self.db_pool.updates._end_background_update("chain_cover")
+            return total_rows_processed
+
+        await self.db_pool.updates._background_update_progress(
+            "chain_cover",
+            {
+                "current_room_id": current_room_id,
+                "last_depth": last_depth,
+                "last_stream": last_stream,
+            },
+        )
+
+        return total_rows_processed
+
+    def _calculate_chain_cover_txn(
+        self,
+        txn: Cursor,
+        last_room_id: str,
+        last_depth: int,
+        last_stream: int,
+        batch_size: Optional[int],
+        single_room: bool,
+    ) -> _CalculateChainCover:
+        """Calculate the chain cover for `batch_size` events, ordered by
+        `(room_id, depth, stream)`.
+
+        Args:
+            txn,
+            last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
+                tuple to fetch results after.
+            batch_size: The maximum number of events to process. If None then
+                no limit.
+            single_room: Whether to calculate the index for just the given
+                room.
+        """
+
+        # Get the next set of events in the room (that we haven't already
+        # computed chain cover for). We do this in topological order.
+
+        # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+        # comparison, but that is not supported on older SQLite versions
+        tuple_clause, tuple_args = make_tuple_comparison_clause(
+            self.database_engine,
+            [
+                ("events.room_id", last_room_id),
+                ("topological_ordering", last_depth),
+                ("stream_ordering", last_stream),
+            ],
+        )
+
+        extra_clause = ""
+        if single_room:
+            extra_clause = "AND events.room_id = ?"
+            tuple_args.append(last_room_id)
+
+        sql = """
+            SELECT
+                event_id, state_events.type, state_events.state_key,
+                topological_ordering, stream_ordering,
+                events.room_id
+            FROM events
+            INNER JOIN state_events USING (event_id)
+            LEFT JOIN event_auth_chains USING (event_id)
+            LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+            WHERE event_auth_chains.event_id IS NULL
+                AND event_auth_chain_to_calculate.event_id IS NULL
+                AND %(tuple_cmp)s
+                %(extra)s
+            ORDER BY events.room_id, topological_ordering, stream_ordering
+            %(limit)s
+        """ % {
+            "tuple_cmp": tuple_clause,
+            "limit": "LIMIT ?" if batch_size is not None else "",
+            "extra": extra_clause,
+        }
+
+        if batch_size is not None:
+            tuple_args.append(batch_size)
+
+        txn.execute(sql, tuple_args)
+        rows = txn.fetchall()
+
+        # Put the results in the necessary format for
+        # `_add_chain_cover_index`
+        event_to_room_id = {row[0]: row[5] for row in rows}
+        event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+        # Calculate the new last position we've processed up to.
+        new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+        new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+        new_last_room_id = rows[-1][5] if rows else ""  # type: str
+
+        # Map from room_id to last depth/stream_ordering processed for the room,
+        # excluding the last room (which we're likely still processing). We also
+        # need to include the room passed in if it's not included in the result
+        # set (as we then know we've processed all events in said room).
+        #
+        # This is the set of rooms that we can now safely flip the
+        # `has_auth_chain_index` bit for.
+        finished_rooms = {
+            row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+        }
+        if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+            finished_rooms[last_room_id] = (last_depth, last_stream)
+
+        count = len(rows)
+
+        # We also need to fetch the auth events for them.
+        auth_events = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth",
+            column="event_id",
+            iterable=event_to_room_id,
+            keyvalues={},
+            retcols=("event_id", "auth_id"),
+        )
+
+        event_to_auth_chain = {}  # type: Dict[str, List[str]]
+        for row in auth_events:
+            event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+        # Calculate and persist the chain cover index for this set of events.
+        #
+        # Annoyingly we need to gut wrench into the persit event store so that
+        # we can reuse the function to calculate the chain cover for rooms.
+        PersistEventsStore._add_chain_cover_index(
+            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+        )
+
+        return _CalculateChainCover(
+            room_id=new_last_room_id,
+            depth=new_last_depth,
+            stream=new_last_stream,
+            processed_count=count,
+            finished_room_map=finished_rooms,
+        )
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4732685f6e..71d823be72 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="events",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_stream_seq",
                 writers=hs.config.worker.writers.events,
             )
@@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="backfill",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_backfill_stream_seq",
                 positive=False,
                 writers=hs.config.worker.writers.events,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 4b2f224718..e017177655 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_local_media_before(
         self, before_ts: int, size_gt: int, keep_profiles: bool,
-    ) -> Optional[List[str]]:
+    ) -> List[str]:
 
         # to find files that have never been accessed (last_access_ts IS NULL)
         # compare with `created_ts`
@@ -416,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_origin = ? AND media_id = ?"
             )
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (time_ms, media_origin, media_id)
@@ -429,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_id = ?"
             )
 
-            txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+            txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
 
         return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
@@ -556,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache", _delete_url_cache_txn
@@ -585,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_media_txn(txn):
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
             sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 0e25ca3d7a..54ef0f1f54 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def set_profile_avatar_url(
-        self, user_localpart: str, new_avatar_url: str
+        self, user_localpart: str, new_avatar_url: Optional[str]
     ) -> None:
         await self.db_pool.simple_update_one(
             table="profiles",
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         )
 
         # Update backward extremeties
-        txn.executemany(
+        txn.execute_batch(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
             [(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 77ba9d819e..bc7621b8d6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -17,14 +17,13 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
 
-from canonicaljson import encode_canonical_json
-
 from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
+from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
@@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore):
                     "device_display_name": device_display_name,
                     "ts": pushkey_ts,
                     "lang": lang,
-                    "data": bytearray(encode_canonical_json(data)),
+                    "data": json_encoder.encode(data),
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1e7949a323..e4843a202c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,15 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
 import logging
 from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet import defer
 
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
-    """This is an abstract base class where subclasses must implement
-    `get_max_receipt_stream_id` which can be called in the initializer.
-    """
-
+class ReceiptsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_receipts = (
+                self._instance_name in hs.config.worker.writers.receipts
+            )
+
+            self._receipts_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="receipts",
+                instance_name=self._instance_name,
+                tables=[("receipts_linearized", "instance_name", "stream_id")],
+                sequence_name="receipts_sequence",
+                writers=hs.config.worker.writers.receipts,
+            )
+        else:
+            self._can_write_to_receipts = True
+
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                self._receipts_id_gen = StreamIdGenerator(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+            else:
+                self._receipts_id_gen = SlavedIdTracker(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+
         super().__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
-    @abc.abstractmethod
     def get_max_receipt_stream_id(self):
         """Get the current max stream ID for receipts stream
 
         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._receipts_id_gen.get_current_token()
 
     @cached()
     async def get_users_with_read_receipts_in_room(self, room_id):
@@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
-
-class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = StreamIdGenerator(
-            db_conn, "receipts_linearized", "stream_id"
+    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+        self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self.get_last_receipt_event_id_for_user.invalidate(
+            (user_id, room_id, receipt_type)
         )
+        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+        self.get_receipts_for_room.invalidate((room_id, receipt_type))
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def insert_linearized_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
         """
+        assert self._can_write_to_receipts
+
         res = self.db_pool.simple_select_one_txn(
             txn,
             table="events",
@@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
                     )
                     return None
 
-        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
-        txn.call_after(
-            self._invalidate_get_users_with_receipts_in_room,
-            room_id,
-            receipt_type,
-            user_id,
-        )
-        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
-        # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(
-            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
+            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
         )
 
         txn.call_after(
             self._receipts_stream_cache.entity_has_changed, room_id, stream_id
         )
 
-        txn.call_after(
-            self.get_last_receipt_event_id_for_user.invalidate,
-            (user_id, room_id, receipt_type),
-        )
-
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_linearized",
@@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
         Automatically does conversion between linearized and graph
         representations.
         """
+        assert self._can_write_to_receipts
+
         if not event_ids:
             return None
 
@@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     async def insert_graph_receipt(
         self, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         return await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
@@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     def insert_graph_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
         txn.call_after(
             self._invalidate_get_users_with_receipts_in_room,
@@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "data": json_encoder.encode(data),
             },
         )
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 8d05288ed4..585b4049d6 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1104,7 +1104,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 FROM user_threepids
             """
 
-            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+            txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4650d0689b..a9fcb5f59c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -16,7 +16,6 @@
 
 import collections
 import logging
-import re
 from abc import abstractmethod
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
@@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
+from synapse.util.stringutils import MXC_REGEX
 
 logger = logging.getLogger(__name__)
 
@@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore):
         return await self.db_pool.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
-            retcols=("room_id", "is_public", "creator"),
+            retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
             desc="get_room",
             allow_none=True,
         )
@@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore):
             The local and remote media as a lists of tuples where the key is
             the hostname and the value is the media ID.
         """
-        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
-
         sql = """
             SELECT stream_ordering, json FROM events
             JOIN event_json USING (room_id, event_id)
@@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore):
                 for url in (content_url, thumbnail_url):
                     if not url:
                         continue
-                    matches = mxc_re.match(url)
+                    matches = MXC_REGEX.match(url)
                     if matches:
                         hostname = matches.group(1)
                         media_id = matches.group(2)
@@ -1166,6 +1164,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         # It's overridden by RoomStore for the synapse master.
         raise NotImplementedError()
 
+    async def has_auth_chain_index(self, room_id: str) -> bool:
+        """Check if the room has (or can have) a chain cover index.
+
+        Defaults to True if we don't have an entry in `rooms` table nor any
+        events for the room.
+        """
+
+        has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcol="has_auth_chain_index",
+            desc="has_auth_chain_index",
+            allow_none=True,
+        )
+
+        if has_auth_chain_index:
+            return True
+
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        max_ordering = await self.db_pool.simple_select_one_onecol(
+            table="events",
+            keyvalues={"room_id": room_id},
+            retcol="MAX(stream_ordering)",
+            allow_none=True,
+            desc="upsert_room_on_join",
+        )
+
+        return max_ordering is None
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1179,12 +1208,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         Called when we join a room over federation, and overwrites any room version
         currently in the table.
         """
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
         await self.db_pool.simple_upsert(
             desc="upsert_room_on_join",
             table="rooms",
             keyvalues={"room_id": room_id},
             values={"room_version": room_version.identifier},
-            insertion_values={"is_public": False, "creator": ""},
+            insertion_values={
+                "is_public": False,
+                "creator": "",
+                "has_auth_chain_index": has_auth_chain_index,
+            },
             # rooms has a unique constraint on room_id, so no need to lock when doing an
             # emulated upsert.
             lock=False,
@@ -1219,6 +1257,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         "creator": room_creator_user_id,
                         "is_public": is_public,
                         "room_version": room_version.identifier,
+                        "has_auth_chain_index": True,
                     },
                 )
                 if is_public:
@@ -1247,6 +1286,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         When we receive an invite or any other event over federation that may relate to a room
         we are not in, store the version of the room if we don't already know the room version.
         """
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
         await self.db_pool.simple_upsert(
             desc="maybe_store_room_on_outlier_membership",
             table="rooms",
@@ -1256,6 +1300,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 "room_version": room_version.identifier,
                 "is_public": False,
                 "creator": "",
+                "has_auth_chain_index": has_auth_chain_index,
             },
             # rooms has a unique constraint on room_id, so no need to lock when doing an
             # emulated upsert.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..92382bed28 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             "max_stream_id_exclusive", self._stream_order_on_start + 1
         )
 
-        INSERT_CLUMP_SIZE = 1000
-
         def add_membership_profile_txn(txn):
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
                 UPDATE room_memberships SET display_name = ?, avatar_url = ?
                 WHERE event_id = ? AND room_id = ?
             """
-            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(to_update_sql, clump)
+            txn.execute_batch(to_update_sql, to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
new file mode 100644
index 0000000000..de57645019
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE access_tokens DROP COLUMN last_used;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
new file mode 100644
index 0000000000..ee0e3521bf
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ -- Dropping last_used column from access_tokens table.
+
+CREATE TABLE access_tokens2 (
+    id BIGINT PRIMARY KEY, 
+    user_id TEXT NOT NULL, 
+    device_id TEXT, 
+    token TEXT NOT NULL,
+    valid_until_ms BIGINT,
+    puppets_user_id TEXT,
+    last_validated BIGINT,
+    UNIQUE(token) 
+);
+
+INSERT INTO access_tokens2(id, user_id, device_id, token)
+    SELECT id, user_id, device_id, token FROM access_tokens;
+
+DROP TABLE access_tokens;
+ALTER TABLE access_tokens2 RENAME TO access_tokens;
+
+CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id);
+
+
+-- Re-adding foreign key reference in event_txn_id table
+
+CREATE TABLE event_txn_id2 (
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    token_id BIGINT NOT NULL,
+    txn_id TEXT NOT NULL,
+    inserted_ts BIGINT NOT NULL,
+    FOREIGN KEY (event_id)
+        REFERENCES events (event_id) ON DELETE CASCADE,
+    FOREIGN KEY (token_id)
+        REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+INSERT INTO event_txn_id2(event_id, room_id, user_id, token_id, txn_id, inserted_ts)
+    SELECT event_id, room_id, user_id, token_id, txn_id, inserted_ts FROM event_txn_id;
+
+DROP TABLE event_txn_id;
+ALTER TABLE event_txn_id2 RENAME TO event_txn_id;
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql
new file mode 100644
index 0000000000..9c95646281
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5828, 'rejected_events_metadata', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644
--- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
         # { "ignored_users": "@someone:example.org": {} }
         ignored_users = content.get("ignored_users", {})
         if isinstance(ignored_users, dict) and ignored_users:
-            cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+            cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
 
     # Add indexes after inserting data for efficiency.
     logger.info("Adding constraints to ignored_users table")
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
new file mode 100644
index 0000000000..729196cfd5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
@@ -0,0 +1,52 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- See docs/auth_chain_difference_algorithm.md
+
+CREATE TABLE event_auth_chains (
+  event_id TEXT PRIMARY KEY,
+  chain_id BIGINT NOT NULL,
+  sequence_number BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number);
+
+
+CREATE TABLE event_auth_chain_links (
+  origin_chain_id BIGINT NOT NULL,
+  origin_sequence_number BIGINT NOT NULL,
+
+  target_chain_id BIGINT NOT NULL,
+  target_sequence_number BIGINT NOT NULL
+);
+
+
+CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id);
+
+
+-- Events that we have persisted but not calculated auth chains for,
+-- e.g. out of band memberships (where we don't have the auth chain)
+CREATE TABLE event_auth_chain_to_calculate (
+  event_id TEXT PRIMARY KEY,
+  room_id TEXT NOT NULL,
+  type TEXT NOT NULL,
+  state_key TEXT NOT NULL
+);
+
+CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id);
+
+
+-- Whether we've calculated the above index for a room.
+ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
new file mode 100644
index 0000000000..e8a035bbeb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
new file mode 100644
index 0000000000..64ab696cfe
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS account_data_max_stream_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
new file mode 100644
index 0000000000..fb71b360a0
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS cache_invalidation_stream;
diff --git a/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
new file mode 100644
index 0000000000..fe3dca71dd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (5906, 'chain_cover', '{}', 'rejected_events_metadata');
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
new file mode 100644
index 0000000000..46abf8d562
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
+ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
+ALTER TABLE account_data ADD COLUMN instance_name TEXT;
+
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
new file mode 100644
index 0000000000..4a6e6c74f5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
@@ -0,0 +1,32 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
+
+-- We need to take the max across all the account_data tables as they share the
+-- ID generator
+SELECT setval('account_data_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
+    )
+));
+
+CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
+
+SELECT setval('receipts_sequence', (
+    SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
+));
diff --git a/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
new file mode 100644
index 0000000000..9f2b5ebc5a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We incorrectly populated these, so we delete them and let the
+-- MultiWriterIdGenerator repopulate it.
+DELETE FROM stream_positions WHERE stream_name = 'receipts' OR stream_name = 'account_data';
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..871af64b11 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -63,7 +63,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
@@ -75,7 +75,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 9f120d3cb6..50067eabfc 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
         )
         return {row["tag"]: db_to_json(row["content"]) for row in rows}
 
-
-class TagsStore(TagsWorkerStore):
     async def add_tag_to_room(
         self, user_id: str, room_id: str, tag: str, content: JsonDict
     ) -> int:
@@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
+        assert self._can_write_to_account_data
+
         content_json = json_encoder.encode(content)
 
         def add_tag_txn(txn, next_id):
@@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
+        assert self._can_write_to_account_data
 
         def remove_tag_txn(txn, next_id):
             sql = (
@@ -250,21 +251,12 @@ class TagsStore(TagsWorkerStore):
             room_id: The ID of the room.
             next_id: The the revision to advance to.
         """
+        assert self._can_write_to_account_data
 
         txn.call_after(
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
         )
 
-        # Note: This is only here for backwards compat to allow admins to
-        # roll back to a previous Synapse version. Next time we update the
-        # database version we can remove this table.
-        update_max_id_sql = (
-            "UPDATE account_data_max_stream_id"
-            " SET stream_id = ?"
-            " WHERE stream_id < ?"
-        )
-        txn.execute(update_max_id_sql, (next_id, next_id))
-
         update_sql = (
             "UPDATE room_tags_revisions"
             " SET stream_id = ?"
@@ -288,3 +280,7 @@ class TagsStore(TagsWorkerStore):
                 # which stream_id ends up in the table, as long as it is higher
                 # than the id that the client has.
                 pass
+
+
+class TagsStore(TagsWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 59207cadd4..cea595ff19 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -464,19 +464,17 @@ class TransactionStore(TransactionWorkerStore):
         txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
     ) -> List[str]:
         q = """
-            SELECT destination FROM destinations
-                WHERE destination IN (
-                    SELECT destination FROM destination_rooms
-                        WHERE destination_rooms.stream_ordering >
-                            destinations.last_successful_stream_ordering
-                )
-                AND destination > ?
-                AND (
-                    retry_last_ts IS NULL OR
-                    retry_last_ts + retry_interval < ?
-                )
-                ORDER BY destination
-                LIMIT 25
+            SELECT DISTINCT destination FROM destinations
+            INNER JOIN destination_rooms USING (destination)
+                WHERE
+                    stream_ordering > last_successful_stream_ordering
+                    AND destination > ?
+                    AND (
+                        retry_last_ts IS NULL OR
+                        retry_last_ts + retry_interval < ?
+                    )
+                    ORDER BY destination
+                    LIMIT 25
         """
         txn.execute(
             q,
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..89cdc84a9c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             )
 
         logger.info("[purge] removing redundant state groups")
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups WHERE id = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 01efb2cabb..566ea19bae 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -35,9 +35,6 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-# XXX: If you're about to bump this to 59 (or higher) please create an update
-# that drops the unused `cache_invalidation_stream` table, as per #7436!
-# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
 SCHEMA_VERSION = 59
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 133c0e7a28..71ef5a72dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
 import heapq
 import logging
 import threading
-from collections import deque
+from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import attr
-from typing_extensions import Deque
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()  # type: Deque[int]
+
+        # We use this as an ordered set, as we want to efficiently append items,
+        # remove items and get the first item. Since we insert IDs in order, the
+        # insertion ordering will ensure its in the correct ordering.
+        #
+        # The key and values are the same, but we never look at the values.
+        self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]
 
     def get_next(self):
         """
@@ -113,7 +118,7 @@ class StreamIdGenerator:
             self._current += self._step
             next_id = self._current
 
-            self._unfinished_ids.append(next_id)
+            self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
                 yield next_id
             finally:
                 with self._lock:
-                    self._unfinished_ids.remove(next_id)
+                    self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
@@ -140,7 +145,7 @@ class StreamIdGenerator:
             self._current += n * self._step
 
             for next_id in next_ids:
-                self._unfinished_ids.append(next_id)
+                self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
             finally:
                 with self._lock:
                     for next_id in next_ids:
-                        self._unfinished_ids.remove(next_id)
+                        self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
@@ -162,7 +167,7 @@ class StreamIdGenerator:
         """
         with self._lock:
             if self._unfinished_ids:
-                return self._unfinished_ids[0] - self._step
+                return next(iter(self._unfinished_ids)) - self._step
 
             return self._current
 
@@ -186,11 +191,12 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
-        stream_name: A name for the stream.
+        stream_name: A name for the stream, for use in the `stream_positions`
+            table. (Does not need to be the same as the replication stream name)
         instance_name: The name of this instance.
-        table: Database table associated with stream.
-        instance_column: Column that stores the row's writer's instance name
-        id_column: Column that stores the stream ID.
+        tables: List of tables associated with the stream. Tuple of table
+            name, column name that stores the writer's instance name, and
+            column name that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
         writers: A list of known writers to use to populate current positions
@@ -206,9 +212,7 @@ class MultiWriterIdGenerator:
         db: DatabasePool,
         stream_name: str,
         instance_name: str,
-        table: str,
-        instance_column: str,
-        id_column: str,
+        tables: List[Tuple[str, str, str]],
         sequence_name: str,
         writers: List[str],
         positive: bool = True,
@@ -260,15 +264,20 @@ class MultiWriterIdGenerator:
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
         # We check that the table and sequence haven't diverged.
-        self._sequence_gen.check_consistency(
-            db_conn, table=table, id_column=id_column, positive=positive
-        )
+        for table, _, id_column in tables:
+            self._sequence_gen.check_consistency(
+                db_conn,
+                table=table,
+                id_column=id_column,
+                stream_name=stream_name,
+                positive=positive,
+            )
 
         # This goes and fills out the above state from the database.
-        self._load_current_ids(db_conn, table, instance_column, id_column)
+        self._load_current_ids(db_conn, tables)
 
     def _load_current_ids(
-        self, db_conn, table: str, instance_column: str, id_column: str
+        self, db_conn, tables: List[Tuple[str, str, str]],
     ):
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
@@ -306,17 +315,22 @@ class MultiWriterIdGenerator:
             # We add a GREATEST here to ensure that the result is always
             # positive. (This can be a problem for e.g. backfill streams where
             # the server has never backfilled).
-            sql = """
-                SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
-                FROM %(table)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "agg": "MAX" if self._positive else "-MIN",
-            }
-            cur.execute(sql)
-            (stream_id,) = cur.fetchone()
-            self._persisted_upto_position = stream_id
+            max_stream_id = 1
+            for table, _, id_column in tables:
+                sql = """
+                    SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                    FROM %(table)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "agg": "MAX" if self._positive else "-MIN",
+                }
+                cur.execute(sql)
+                (stream_id,) = cur.fetchone()
+
+                max_stream_id = max(max_stream_id, stream_id)
+
+            self._persisted_upto_position = max_stream_id
         else:
             # If we have a min_stream_id then we pull out everything greater
             # than it from the DB so that we can prefill
@@ -329,21 +343,28 @@ class MultiWriterIdGenerator:
             # stream positions table before restart (or the stream position
             # table otherwise got out of date).
 
-            sql = """
-                SELECT %(instance)s, %(id)s FROM %(table)s
-                WHERE ? %(cmp)s %(id)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "instance": instance_column,
-                "cmp": "<=" if self._positive else ">=",
-            }
-            cur.execute(sql, (min_stream_id * self._return_factor,))
-
             self._persisted_upto_position = min_stream_id
 
+            rows = []
+            for table, instance_column, id_column in tables:
+                sql = """
+                    SELECT %(instance)s, %(id)s FROM %(table)s
+                    WHERE ? %(cmp)s %(id)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "instance": instance_column,
+                    "cmp": "<=" if self._positive else ">=",
+                }
+                cur.execute(sql, (min_stream_id * self._return_factor,))
+
+                rows.extend(cur)
+
+            # Sort so that we handle rows in order for each instance.
+            rows.sort()
+
             with self._lock:
-                for (instance, stream_id,) in cur:
+                for (instance, stream_id,) in rows:
                     stream_id = self._return_factor * stream_id
                     self._add_persisted_position(stream_id)
 
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 4386b6101e..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -15,9 +15,8 @@
 import abc
 import logging
 import threading
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
-from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
@@ -25,6 +24,9 @@ from synapse.storage.engines import (
 )
 from synapse.storage.types import Connection, Cursor
 
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
+
 logger = logging.getLogger(__name__)
 
 
@@ -43,6 +45,21 @@ and run the following SQL:
 See docs/postgres.md for more information.
 """
 
+_INCONSISTENT_STREAM_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated stream position
+of '%(stream_name)s' in the 'stream_positions' table.
+
+This is likely a programming error and should be reported at
+https://github.com/matrix-org/synapse.
+
+A temporary workaround to fix this error is to shut down Synapse (including
+any and all workers) and run the following SQL:
+
+    DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s';
+
+This will need to be done every time the server is restarted.
+"""
+
 
 class SequenceGenerator(metaclass=abc.ABCMeta):
     """A class which generates a unique sequence of integers"""
@@ -53,19 +70,30 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        """Get the next `n` IDs in the sequence"""
+        ...
+
+    @abc.abstractmethod
     def check_consistency(
         self,
-        db_conn: LoggingDatabaseConnection,
+        db_conn: "LoggingDatabaseConnection",
         table: str,
         id_column: str,
+        stream_name: Optional[str] = None,
         positive: bool = True,
     ):
         """Should be called during start up to test that the current value of
         the sequence is greater than or equal to the maximum ID in the table.
 
-        This is to handle various cases where the sequence value can get out
-        of sync with the table, e.g. if Synapse gets rolled back to a previous
+        This is to handle various cases where the sequence value can get out of
+        sync with the table, e.g. if Synapse gets rolled back to a previous
         version and the rolled forwards again.
+
+        If a stream name is given then this will check that any value in the
+        `stream_positions` table is less than or equal to the current sequence
+        value. If it isn't then it's likely that streams have been crossed
+        somewhere (e.g. two ID generators have the same stream name).
         """
         ...
 
@@ -88,11 +116,15 @@ class PostgresSequenceGenerator(SequenceGenerator):
 
     def check_consistency(
         self,
-        db_conn: LoggingDatabaseConnection,
+        db_conn: "LoggingDatabaseConnection",
         table: str,
         id_column: str,
+        stream_name: Optional[str] = None,
         positive: bool = True,
     ):
+        """See SequenceGenerator.check_consistency for docstring.
+        """
+
         txn = db_conn.cursor(txn_name="sequence.check_consistency")
 
         # First we get the current max ID from the table.
@@ -116,6 +148,18 @@ class PostgresSequenceGenerator(SequenceGenerator):
             "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
         )
         last_value, is_called = txn.fetchone()
+
+        # If we have an associated stream check the stream_positions table.
+        max_in_stream_positions = None
+        if stream_name:
+            txn.execute(
+                "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?",
+                (stream_name,),
+            )
+            row = txn.fetchone()
+            if row:
+                max_in_stream_positions = row[0]
+
         txn.close()
 
         # If `is_called` is False then `last_value` is actually the value that
@@ -136,6 +180,14 @@ class PostgresSequenceGenerator(SequenceGenerator):
                 % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
             )
 
+        # If we have values in the stream positions table then they have to be
+        # less than or equal to `last_value`
+        if max_in_stream_positions and max_in_stream_positions > last_value:
+            raise IncorrectDatabaseSetup(
+                _INCONSISTENT_STREAM_ERROR
+                % {"seq": self._sequence_name, "stream_name": stream_name}
+            )
+
 
 GetFirstCallbackType = Callable[[Cursor], int]
 
@@ -172,8 +224,24 @@ class LocalSequenceGenerator(SequenceGenerator):
             self._current_max_id += 1
             return self._current_max_id
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            first_id = self._current_max_id + 1
+            self._current_max_id += n
+            return [first_id + i for i in range(n)]
+
     def check_consistency(
-        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+        self,
+        db_conn: Connection,
+        table: str,
+        id_column: str,
+        stream_name: Optional[str] = None,
+        positive: bool = True,
     ):
         # There is nothing to do for in memory sequences
         pass
diff --git a/synapse/types.py b/synapse/types.py
index c7d4e95809..eafe729dfe 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -37,6 +37,7 @@ from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationService
@@ -257,8 +258,13 @@ class DomainSpecificString(
 
     @classmethod
     def is_valid(cls: Type[DS], s: str) -> bool:
+        """Parses the input string and attempts to ensure it is valid."""
         try:
-            cls.from_string(s)
+            obj = cls.from_string(s)
+            # Apply additional validation to the domain. This is only done
+            # during  is_valid (and not part of from_string) since it is
+            # possible for invalid data to exist in room-state, etc.
+            parse_and_validate_server_name(obj.domain)
             return True
         except Exception:
             return False
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 06faeebe7f..6ef2b008a4 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -13,8 +13,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import heapq
 from itertools import islice
-from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+from typing import (
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    Mapping,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+)
+
+from synapse.types import Collection
 
 T = TypeVar("T")
 
@@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
     If the input is empty, no chunks are returned.
     """
     return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+
+
+def sorted_topologically(
+    nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+) -> Generator[T, None, None]:
+    """Given a set of nodes and a graph, yield the nodes in toplogical order.
+
+    For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
+    """
+
+    # This is implemented by Kahn's algorithm.
+
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph = {}  # type: Dict[T, Set[T]]
+
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+
+        for edge in edges:
+            if edge in degree_map:
+                degree_map[node] += 1
+
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+    heapq.heapify(zero_degree)
+
+    while zero_degree:
+        node = heapq.heappop(zero_degree)
+        yield node
+
+        for edge in reverse_graph.get(node, []):
+            if edge in degree_map:
+                degree_map[edge] -= 1
+                if degree_map[edge] == 0:
+                    heapq.heappush(zero_degree, edge)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 61d96a6c28..f8038bf861 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -18,6 +18,7 @@ import random
 import re
 import string
 from collections.abc import Iterable
+from typing import Optional, Tuple
 
 from synapse.api.errors import Codes, SynapseError
 
@@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
 client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
 
+# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
+# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
+# says "there is no grammar for media ids"
+#
+# The server_name part of this is purposely lax: use parse_and_validate_mxc for
+# additional validation.
+#
+MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
+
 # random_string and random_string_with_symbols are used for a range of things,
 # some cryptographically important, some less so. We use SystemRandom to make sure
 # we get cryptographically-secure randoms.
@@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret):
         )
 
 
+def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+    """Split a server name into host/port parts.
+
+    Args:
+        server_name: server name to parse
+
+    Returns:
+        host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    try:
+        if server_name[-1] == "]":
+            # ipv6 literal, hopefully
+            return server_name, None
+
+        domain_port = server_name.rsplit(":", 1)
+        domain = domain_port[0]
+        port = int(domain_port[1]) if domain_port[1:] else None
+        return domain, port
+    except Exception:
+        raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
+
+
+def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+    """Split a server name into host/port parts and do some basic validation.
+
+    Args:
+        server_name: server name to parse
+
+    Returns:
+        host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    host, port = parse_server_name(server_name)
+
+    # these tests don't need to be bulletproof as we'll find out soon enough
+    # if somebody is giving us invalid data. What we *do* need is to be sure
+    # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+    # look for ipv6 literals
+    if host[0] == "[":
+        if host[-1] != "]":
+            raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
+        return host, port
+
+    # otherwise it should only be alphanumerics.
+    if not VALID_HOST_REGEX.match(host):
+        raise ValueError(
+            "Server name '%s' contains invalid characters" % (server_name,)
+        )
+
+    return host, port
+
+
+def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
+    """Parse the given string as an MXC URI
+
+    Checks that the "server name" part is a valid server name
+
+    Args:
+        mxc: the (alleged) MXC URI to be checked
+    Returns:
+        hostname, port, media id
+    Raises:
+        ValueError if the URI cannot be parsed
+    """
+    m = MXC_REGEX.match(mxc)
+    if not m:
+        raise ValueError("mxc URI %r did not match expected format" % (mxc,))
+    server_name = m.group(1)
+    media_id = m.group(2)
+    host, port = parse_and_validate_server_name(server_name)
+    return host, port, media_id
+
+
 def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
     """If iterable has maxitems or fewer, return the stringification of a list
     containing those items.
@@ -75,3 +167,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
     if len(items) <= maxitems:
         return str(items)
     return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
+
+
+def strtobool(val: str) -> bool:
+    """Convert a string representation of truth to True or False
+
+    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+    are 'n', 'no', 'f', 'false', 'off', and '0'.  Raises ValueError if
+    'val' is anything else.
+
+    This is lifted from distutils.util.strtobool, with the exception that it actually
+    returns a bool, rather than an int.
+    """
+    val = val.lower()
+    if val in ("y", "yes", "t", "true", "on", "1"):
+        return True
+    elif val in ("n", "no", "f", "false", "off", "0"):
+        return False
+    else:
+        raise ValueError("invalid truth value %r" % (val,))