summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-x.buildkite/scripts/create_postgres_db.py1
-rw-r--r--.gitignore1
-rw-r--r--INSTALL.md2
-rw-r--r--README.rst2
-rw-r--r--changelog.d/8868.misc1
-rw-r--r--changelog.d/8932.feature1
-rw-r--r--changelog.d/8939.misc1
-rw-r--r--changelog.d/8948.feature1
-rw-r--r--changelog.d/8984.feature1
-rw-r--r--changelog.d/9015.feature1
-rw-r--r--changelog.d/9016.misc1
-rw-r--r--changelog.d/9017.feature1
-rw-r--r--changelog.d/9018.misc1
-rw-r--r--changelog.d/9023.bugfix1
-rw-r--r--changelog.d/9024.feature1
-rw-r--r--changelog.d/9025.misc1
-rw-r--r--changelog.d/9028.bugfix1
-rw-r--r--changelog.d/9030.misc1
-rw-r--r--changelog.d/9031.misc1
-rw-r--r--changelog.d/9033.misc1
-rw-r--r--changelog.d/9035.doc1
-rw-r--r--changelog.d/9036.feature1
-rw-r--r--changelog.d/9038.misc1
-rw-r--r--changelog.d/9039.removal1
-rw-r--r--changelog.d/9040.doc1
-rw-r--r--changelog.d/9041.misc1
-rw-r--r--changelog.d/9042.feature1
-rw-r--r--changelog.d/9043.feature1
-rw-r--r--changelog.d/9044.feature1
-rw-r--r--changelog.d/9051.bugfix1
-rw-r--r--changelog.d/9053.bugfix1
-rw-r--r--changelog.d/9054.bugfix1
-rw-r--r--changelog.d/9055.misc1
-rw-r--r--changelog.d/9057.doc1
-rw-r--r--changelog.d/9058.misc1
-rw-r--r--changelog.d/9059.bugfix1
-rw-r--r--changelog.d/9063.misc1
-rw-r--r--changelog.d/9067.feature1
-rw-r--r--changelog.d/9068.feature1
-rw-r--r--changelog.d/9069.misc1
-rw-r--r--changelog.d/9070.bugfix1
-rw-r--r--changelog.d/9080.misc1
-rw-r--r--changelog.d/9081.feature1
-rw-r--r--demo/webserver.py59
-rw-r--r--docs/admin_api/user_admin_api.rst25
-rw-r--r--docs/auth_chain_diff.dot32
-rw-r--r--docs/auth_chain_diff.dot.pngbin0 -> 42427 bytes
-rw-r--r--docs/auth_chain_difference_algorithm.md108
-rw-r--r--docs/openid.md4
-rw-r--r--docs/sample_config.yaml25
-rw-r--r--docs/systemd-with-workers/README.md2
-rw-r--r--mypy.ini1
-rwxr-xr-xscripts/synapse_port_db27
-rw-r--r--stubs/frozendict.pyi11
-rw-r--r--stubs/sortedcontainers/sorteddict.pyi6
-rw-r--r--stubs/txredisapi.pyi2
-rw-r--r--synapse/api/auth.py7
-rw-r--r--synapse/api/room_versions.py32
-rw-r--r--synapse/app/_base.py150
-rw-r--r--synapse/app/generic_worker.py22
-rw-r--r--synapse/app/homeserver.py76
-rw-r--r--synapse/config/_util.py2
-rw-r--r--synapse/config/sso.py27
-rw-r--r--synapse/config/workers.py10
-rw-r--r--synapse/events/utils.py16
-rw-r--r--synapse/federation/federation_server.py21
-rw-r--r--synapse/handlers/auth.py95
-rw-r--r--synapse/handlers/cas_handler.py38
-rw-r--r--synapse/handlers/deactivate_account.py18
-rw-r--r--synapse/handlers/devicemessage.py48
-rw-r--r--synapse/handlers/oidc_handler.py320
-rw-r--r--synapse/handlers/profile.py12
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/saml_handler.py28
-rw-r--r--synapse/handlers/sso.py136
-rw-r--r--synapse/handlers/ui_auth/__init__.py15
-rw-r--r--synapse/http/__init__.py15
-rw-r--r--synapse/http/matrixfederationclient.py10
-rw-r--r--synapse/http/site.py18
-rw-r--r--synapse/logging/context.py50
-rw-r--r--synapse/notifier.py39
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py12
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py40
-rw-r--r--synapse/replication/tcp/handler.py9
-rw-r--r--synapse/res/templates/sso_login_idp_picker.html28
-rw-r--r--synapse/rest/admin/users.py29
-rw-r--r--synapse/rest/client/v1/login.py89
-rw-r--r--synapse/rest/client/v2_alpha/account.py44
-rw-r--r--synapse/rest/client/v2_alpha/auth.py53
-rw-r--r--synapse/rest/client/v2_alpha/devices.py12
-rw-r--r--synapse/rest/client/v2_alpha/keys.py6
-rw-r--r--synapse/rest/client/v2_alpha/register.py21
-rw-r--r--synapse/rest/synapse/client/pick_idp.py82
-rw-r--r--synapse/server.py6
-rw-r--r--synapse/static/client/login/style.css5
-rw-r--r--synapse/storage/database.py30
-rw-r--r--synapse/storage/databases/main/__init__.py33
-rw-r--r--synapse/storage/databases/main/account_data.py157
-rw-r--r--synapse/storage/databases/main/client_ips.py54
-rw-r--r--synapse/storage/databases/main/deviceinbox.py252
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py88
-rw-r--r--synapse/storage/databases/main/event_federation.py185
-rw-r--r--synapse/storage/databases/main/events.py584
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py124
-rw-r--r--synapse/storage/databases/main/profile.py2
-rw-r--r--synapse/storage/databases/main/room.py51
-rw-r--r--synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres16
-rw-r--r--synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite62
-rw-r--r--synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/01ignored_user.py82
-rw-r--r--synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres25
-rw-r--r--synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql52
-rw-r--r--synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres16
-rw-r--r--synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql17
-rw-r--r--synapse/storage/databases/main/tags.py10
-rw-r--r--synapse/storage/prepare_database.py22
-rw-r--r--synapse/util/caches/deferred_cache.py2
-rw-r--r--synapse/util/iterutils.py53
-rw-r--r--synapse/util/metrics.py11
-rw-r--r--tests/config/test_util.py53
-rw-r--r--tests/events/test_utils.py185
-rw-r--r--tests/handlers/test_cas.py2
-rw-r--r--tests/handlers/test_oidc.py83
-rw-r--r--tests/handlers/test_profile.py30
-rw-r--r--tests/handlers/test_saml.py2
-rw-r--r--tests/http/test_fedclient.py2
-rw-r--r--tests/rest/admin/test_admin.py3
-rw-r--r--tests/rest/admin/test_event_reports.py4
-rw-r--r--tests/rest/admin/test_media.py2
-rw-r--r--tests/rest/admin/test_room.py2
-rw-r--r--tests/rest/admin/test_statistics.py1
-rw-r--r--tests/rest/admin/test_user.py283
-rw-r--r--tests/rest/client/v1/test_login.py280
-rw-r--r--tests/rest/client/v1/test_rooms.py6
-rw-r--r--tests/rest/client/v1/utils.py3
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py26
-rw-r--r--tests/rest/media/v1/test_url_preview.py7
-rw-r--r--tests/storage/test_account_data.py120
-rw-r--r--tests/storage/test_event_chain.py472
-rw-r--r--tests/storage/test_event_federation.py249
-rw-r--r--tests/storage/test_profile.py26
-rw-r--r--tests/test_preview.py11
-rw-r--r--tests/unittest.py28
-rw-r--r--tests/util/caches/test_deferred_cache.py73
-rw-r--r--tests/util/test_itertools.py41
-rw-r--r--tox.ini28
148 files changed, 4789 insertions, 1203 deletions
diff --git a/.buildkite/scripts/create_postgres_db.py b/.buildkite/scripts/create_postgres_db.py
index df6082b0ac..956339de5c 100755
--- a/.buildkite/scripts/create_postgres_db.py
+++ b/.buildkite/scripts/create_postgres_db.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+
 from synapse.storage.engines import create_engine
 
 logger = logging.getLogger("create_postgres_db")
diff --git a/.gitignore b/.gitignore
index 9bb5bdd647..2bccf19997 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,6 +16,7 @@ _trial_temp*/
 # stuff that is likely to exist when you run a server locally
 /*.db
 /*.log
+/*.log.*
 /*.log.config
 /*.pid
 /.python-version
diff --git a/INSTALL.md b/INSTALL.md
index 598ddceb8c..656833637c 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -257,7 +257,7 @@ for a number of platforms.
 
 #### Docker images and Ansible playbooks
 
-There is an offical synapse image available at
+There is an official synapse image available at
 <https://hub.docker.com/r/matrixdotorg/synapse> which can be used with
 the docker-compose file available at [contrib/docker](contrib/docker). Further
 information on this including configuration options is available in the README
diff --git a/README.rst b/README.rst
index 31ae5cc578..9ff375708b 100644
--- a/README.rst
+++ b/README.rst
@@ -243,7 +243,7 @@ Then update the ``users`` table in the database::
 Synapse Development
 ===================
 
-Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)
+Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_
 
 Before setting up a development environment for synapse, make sure you have the
 system dependencies (such as the python header files) installed - see
diff --git a/changelog.d/8868.misc b/changelog.d/8868.misc
new file mode 100644
index 0000000000..1a11e30944
--- /dev/null
+++ b/changelog.d/8868.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions for new rooms.
diff --git a/changelog.d/8932.feature b/changelog.d/8932.feature
new file mode 100644
index 0000000000..a1d17394d7
--- /dev/null
+++ b/changelog.d/8932.feature
@@ -0,0 +1 @@
+Remove a user's avatar URL and display name when deactivated with the Admin API.
diff --git a/changelog.d/8939.misc b/changelog.d/8939.misc
new file mode 100644
index 0000000000..bf94135fd5
--- /dev/null
+++ b/changelog.d/8939.misc
@@ -0,0 +1 @@
+Various clean-ups to the structured logging and logging context code.
diff --git a/changelog.d/8948.feature b/changelog.d/8948.feature
new file mode 100644
index 0000000000..3b06cbfa22
--- /dev/null
+++ b/changelog.d/8948.feature
@@ -0,0 +1 @@
+Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users.
diff --git a/changelog.d/8984.feature b/changelog.d/8984.feature
new file mode 100644
index 0000000000..4db629746e
--- /dev/null
+++ b/changelog.d/8984.feature
@@ -0,0 +1 @@
+Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version.
diff --git a/changelog.d/9015.feature b/changelog.d/9015.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9015.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/changelog.d/9016.misc b/changelog.d/9016.misc
new file mode 100644
index 0000000000..0d455b17db
--- /dev/null
+++ b/changelog.d/9016.misc
@@ -0,0 +1 @@
+Ensure rejected events get added to some metadata tables.
diff --git a/changelog.d/9017.feature b/changelog.d/9017.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9017.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/changelog.d/9018.misc b/changelog.d/9018.misc
new file mode 100644
index 0000000000..bb31eb4a46
--- /dev/null
+++ b/changelog.d/9018.misc
@@ -0,0 +1 @@
+Ignore date-rotated homeserver logs saved to disk.
diff --git a/changelog.d/9023.bugfix b/changelog.d/9023.bugfix
new file mode 100644
index 0000000000..deae64d933
--- /dev/null
+++ b/changelog.d/9023.bugfix
@@ -0,0 +1 @@
+Fix a longstanding issue where an internal server error would occur when requesting a profile over federation that did not include a display name / avatar URL.
diff --git a/changelog.d/9024.feature b/changelog.d/9024.feature
new file mode 100644
index 0000000000..073dafbf83
--- /dev/null
+++ b/changelog.d/9024.feature
@@ -0,0 +1 @@
+Improved performance when calculating ignored users in large rooms.
diff --git a/changelog.d/9025.misc b/changelog.d/9025.misc
new file mode 100644
index 0000000000..658f50d853
--- /dev/null
+++ b/changelog.d/9025.misc
@@ -0,0 +1 @@
+Removed an unused column from `access_tokens` table.
diff --git a/changelog.d/9028.bugfix b/changelog.d/9028.bugfix
new file mode 100644
index 0000000000..66666886a4
--- /dev/null
+++ b/changelog.d/9028.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where some caches could grow larger than configured.
diff --git a/changelog.d/9030.misc b/changelog.d/9030.misc
new file mode 100644
index 0000000000..267cfbf9f9
--- /dev/null
+++ b/changelog.d/9030.misc
@@ -0,0 +1 @@
+Add a `-noextras` factor to `tox.ini`, to support running the tests with no optional dependencies.
diff --git a/changelog.d/9031.misc b/changelog.d/9031.misc
new file mode 100644
index 0000000000..f43611c385
--- /dev/null
+++ b/changelog.d/9031.misc
@@ -0,0 +1 @@
+Fix running unit tests when optional dependencies are not installed.
diff --git a/changelog.d/9033.misc b/changelog.d/9033.misc
new file mode 100644
index 0000000000..e9a305c0e8
--- /dev/null
+++ b/changelog.d/9033.misc
@@ -0,0 +1 @@
+Allow bumping schema version when using split out state database.
diff --git a/changelog.d/9035.doc b/changelog.d/9035.doc
new file mode 100644
index 0000000000..2a7f0db518
--- /dev/null
+++ b/changelog.d/9035.doc
@@ -0,0 +1 @@
+Corrected a typo in the `systemd-with-workers` documentation.
diff --git a/changelog.d/9036.feature b/changelog.d/9036.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9036.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/changelog.d/9038.misc b/changelog.d/9038.misc
new file mode 100644
index 0000000000..5b9e21a1db
--- /dev/null
+++ b/changelog.d/9038.misc
@@ -0,0 +1 @@
+Configure the linters to run on a consistent set of files.
diff --git a/changelog.d/9039.removal b/changelog.d/9039.removal
new file mode 100644
index 0000000000..fb99283ed8
--- /dev/null
+++ b/changelog.d/9039.removal
@@ -0,0 +1 @@
+Remove broken and unmaintained `demo/webserver.py` script.
diff --git a/changelog.d/9040.doc b/changelog.d/9040.doc
new file mode 100644
index 0000000000..5c1f7be781
--- /dev/null
+++ b/changelog.d/9040.doc
@@ -0,0 +1 @@
+Corrected a typo in `INSTALL.md`.
diff --git a/changelog.d/9041.misc b/changelog.d/9041.misc
new file mode 100644
index 0000000000..4952fbe8a2
--- /dev/null
+++ b/changelog.d/9041.misc
@@ -0,0 +1 @@
+Various cleanups to device inbox store.
diff --git a/changelog.d/9042.feature b/changelog.d/9042.feature
new file mode 100644
index 0000000000..4ec319f1f2
--- /dev/null
+++ b/changelog.d/9042.feature
@@ -0,0 +1 @@
+Add experimental support for handling and persistence of to-device messages to happen on worker processes.
diff --git a/changelog.d/9043.feature b/changelog.d/9043.feature
new file mode 100644
index 0000000000..4ec319f1f2
--- /dev/null
+++ b/changelog.d/9043.feature
@@ -0,0 +1 @@
+Add experimental support for handling and persistence of to-device messages to happen on worker processes.
diff --git a/changelog.d/9044.feature b/changelog.d/9044.feature
new file mode 100644
index 0000000000..4ec319f1f2
--- /dev/null
+++ b/changelog.d/9044.feature
@@ -0,0 +1 @@
+Add experimental support for handling and persistence of to-device messages to happen on worker processes.
diff --git a/changelog.d/9051.bugfix b/changelog.d/9051.bugfix
new file mode 100644
index 0000000000..272be9d7a3
--- /dev/null
+++ b/changelog.d/9051.bugfix
@@ -0,0 +1 @@
+Fix error handling during insertion of client IPs into the database.
diff --git a/changelog.d/9053.bugfix b/changelog.d/9053.bugfix
new file mode 100644
index 0000000000..3d8bbf11a1
--- /dev/null
+++ b/changelog.d/9053.bugfix
@@ -0,0 +1 @@
+Fix bug where we didn't correctly record CPU time spent in 'on_new_event' block.
diff --git a/changelog.d/9054.bugfix b/changelog.d/9054.bugfix
new file mode 100644
index 0000000000..0bfe951f17
--- /dev/null
+++ b/changelog.d/9054.bugfix
@@ -0,0 +1 @@
+Fix a minor bug which could cause confusing error messages from invalid configurations.
diff --git a/changelog.d/9055.misc b/changelog.d/9055.misc
new file mode 100644
index 0000000000..8e0512eb1e
--- /dev/null
+++ b/changelog.d/9055.misc
@@ -0,0 +1 @@
+Drop unused database tables.
diff --git a/changelog.d/9057.doc b/changelog.d/9057.doc
new file mode 100644
index 0000000000..d16686e7dc
--- /dev/null
+++ b/changelog.d/9057.doc
@@ -0,0 +1 @@
+Add missing user_mapping_provider configuration to the Keycloak OIDC example. Contributed by @chris-ruecker.
diff --git a/changelog.d/9058.misc b/changelog.d/9058.misc
new file mode 100644
index 0000000000..9df6796e22
--- /dev/null
+++ b/changelog.d/9058.misc
@@ -0,0 +1 @@
+Remove unused `SynapseService` class.
diff --git a/changelog.d/9059.bugfix b/changelog.d/9059.bugfix
new file mode 100644
index 0000000000..2933703ffa
--- /dev/null
+++ b/changelog.d/9059.bugfix
@@ -0,0 +1 @@
+Fix incorrect exit code when there is an error at startup.
diff --git a/changelog.d/9063.misc b/changelog.d/9063.misc
new file mode 100644
index 0000000000..22eed43147
--- /dev/null
+++ b/changelog.d/9063.misc
@@ -0,0 +1 @@
+Removes unnecessary declarations in the tests for the admin API.
\ No newline at end of file
diff --git a/changelog.d/9067.feature b/changelog.d/9067.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9067.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/changelog.d/9068.feature b/changelog.d/9068.feature
new file mode 100644
index 0000000000..cdf1844fa7
--- /dev/null
+++ b/changelog.d/9068.feature
@@ -0,0 +1 @@
+Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes.
diff --git a/changelog.d/9069.misc b/changelog.d/9069.misc
new file mode 100644
index 0000000000..5e9e62d252
--- /dev/null
+++ b/changelog.d/9069.misc
@@ -0,0 +1 @@
+Remove `SynapseRequest.get_user_agent`.
diff --git a/changelog.d/9070.bugfix b/changelog.d/9070.bugfix
new file mode 100644
index 0000000000..72b8fe9f1c
--- /dev/null
+++ b/changelog.d/9070.bugfix
@@ -0,0 +1 @@
+Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers.
diff --git a/changelog.d/9080.misc b/changelog.d/9080.misc
new file mode 100644
index 0000000000..3da8171f5f
--- /dev/null
+++ b/changelog.d/9080.misc
@@ -0,0 +1 @@
+Remove redundant `Homeserver.get_ip_from_request` method.
diff --git a/changelog.d/9081.feature b/changelog.d/9081.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9081.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/demo/webserver.py b/demo/webserver.py
deleted file mode 100644
index ba176d3bd2..0000000000
--- a/demo/webserver.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import argparse
-import BaseHTTPServer
-import os
-import SimpleHTTPServer
-import cgi, logging
-
-from daemonize import Daemonize
-
-
-class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler):
-    UPLOAD_PATH = "upload"
-
-    """
-    Accept all post request as file upload
-    """
-
-    def do_POST(self):
-
-        path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path))
-        length = self.headers["content-length"]
-        data = self.rfile.read(int(length))
-
-        with open(path, "wb") as fh:
-            fh.write(data)
-
-        self.send_response(200)
-        self.send_header("Content-Type", "application/json")
-        self.end_headers()
-
-        # Return the absolute path of the uploaded file
-        self.wfile.write('{"url":"/%s"}' % path)
-
-
-def setup():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("directory")
-    parser.add_argument("-p", "--port", dest="port", type=int, default=8080)
-    parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid")
-    args = parser.parse_args()
-
-    # Get absolute path to directory to serve, as daemonize changes to '/'
-    os.chdir(args.directory)
-    dr = os.getcwd()
-
-    httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST)
-
-    def run():
-        os.chdir(dr)
-        httpd.serve_forever()
-
-    daemon = Daemonize(
-        app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False
-    )
-
-    daemon.start()
-
-
-if __name__ == "__main__":
-    setup()
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index e4d6f8203b..b3d413cf57 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -98,6 +98,8 @@ Body parameters:
 
 - ``deactivated``, optional. If unspecified, deactivation state will be left
   unchanged on existing accounts and set to ``false`` for new accounts.
+  A user cannot be erased by deactivating with this API. For details on deactivating users see
+  `Deactivate Account <#deactivate-account>`_.
 
 If the user already exists then optional parameters default to the current value.
 
@@ -248,6 +250,25 @@ server admin: see `README.rst <README.rst>`_.
 The erase parameter is optional and defaults to ``false``.
 An empty body may be passed for backwards compatibility.
 
+The following actions are performed when deactivating an user:
+
+- Try to unpind 3PIDs from the identity server
+- Remove all 3PIDs from the homeserver
+- Delete all devices and E2EE keys
+- Delete all access tokens
+- Delete the password hash
+- Removal from all rooms the user is a member of
+- Remove the user from the user directory
+- Reject all pending invites
+- Remove all account validity information related to the user
+
+The following additional actions are performed during deactivation if``erase``
+is set to ``true``:
+
+- Remove the user's display name
+- Remove the user's avatar URL
+- Mark the user as erased
+
 
 Reset password
 ==============
@@ -337,6 +358,10 @@ A response body like the following is returned:
         "total": 2
     }
 
+The server returns the list of rooms of which the user and the server
+are member. If the user is local, all the rooms of which the user is
+member are returned.
+
 **Parameters**
 
 The following parameters should be set in the URL:
diff --git a/docs/auth_chain_diff.dot b/docs/auth_chain_diff.dot
new file mode 100644
index 0000000000..978d579ada
--- /dev/null
+++ b/docs/auth_chain_diff.dot
@@ -0,0 +1,32 @@
+digraph auth {
+    nodesep=0.5;
+    rankdir="RL";
+
+    C [label="Create (1,1)"];
+
+    BJ [label="Bob's Join (2,1)", color=red];
+    BJ2 [label="Bob's Join (2,2)", color=red];
+    BJ2 -> BJ [color=red, dir=none];
+
+    subgraph cluster_foo {
+        A1 [label="Alice's invite (4,1)", color=blue];
+        A2 [label="Alice's Join (4,2)", color=blue];
+        A3 [label="Alice's Join (4,3)", color=blue];
+        A3 -> A2 -> A1 [color=blue, dir=none];
+        color=none;
+    }
+
+    PL1 [label="Power Level (3,1)", color=darkgreen];
+    PL2 [label="Power Level (3,2)", color=darkgreen];
+    PL2 -> PL1 [color=darkgreen, dir=none];
+
+    {rank = same; C; BJ; PL1; A1;}
+
+    A1 -> C [color=grey];
+    A1 -> BJ [color=grey];
+    PL1 -> C [color=grey];
+    BJ2 -> PL1 [penwidth=2];
+
+    A3 -> PL2 [penwidth=2];
+    A1 -> PL1 -> BJ -> C [penwidth=2];
+}
diff --git a/docs/auth_chain_diff.dot.png b/docs/auth_chain_diff.dot.png
new file mode 100644
index 0000000000..771c07308f
--- /dev/null
+++ b/docs/auth_chain_diff.dot.png
Binary files differdiff --git a/docs/auth_chain_difference_algorithm.md b/docs/auth_chain_difference_algorithm.md
new file mode 100644
index 0000000000..30f72a70da
--- /dev/null
+++ b/docs/auth_chain_difference_algorithm.md
@@ -0,0 +1,108 @@
+# Auth Chain Difference Algorithm
+
+The auth chain difference algorithm is used by V2 state resolution, where a
+naive implementation can be a significant source of CPU and DB usage.
+
+### Definitions
+
+A *state set* is a set of state events; e.g. the input of a state resolution
+algorithm is a collection of state sets.
+
+The *auth chain* of a set of events are all the events' auth events and *their*
+auth events, recursively (i.e. the events reachable by walking the graph induced
+by an event's auth events links).
+
+The *auth chain difference* of a collection of state sets is the union minus the
+intersection of the sets of auth chains corresponding to the state sets, i.e an
+event is in the auth chain difference if it is reachable by walking the auth
+event graph from at least one of the state sets but not from *all* of the state
+sets.
+
+## Breadth First Walk Algorithm
+
+A way of calculating the auth chain difference without calculating the full auth
+chains for each state set is to do a parallel breadth first walk (ordered by
+depth) of each state set's auth chain. By tracking which events are reachable
+from each state set we can finish early if every pending event is reachable from
+every state set.
+
+This can work well for state sets that have a small auth chain difference, but
+can be very inefficient for larger differences. However, this algorithm is still
+used if we don't have a chain cover index for the room (e.g. because we're in
+the process of indexing it).
+
+## Chain Cover Index
+
+Synapse computes auth chain differences by pre-computing a "chain cover" index
+for the auth chain in a room, allowing efficient reachability queries like "is
+event A in the auth chain of event B". This is done by assigning every event a
+*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
+between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
+is in the auth chain of `B`) if and only if either:
+
+1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
+   sequence number; or
+2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
+   `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
+
+There are actually two potential implementations, one where we store links from
+each chain to every other reachable chain (the transitive closure of the links
+graph), and one where we remove redundant links (the transitive reduction of the
+links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
+would not be stored. Synapse uses the former implementations so that it doesn't
+need to recurse to test reachability between chains.
+
+### Example
+
+An example auth graph would look like the following, where chains have been
+formed based on type/state_key and are denoted by colour and are labelled with
+`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
+are those that would be remove in the second implementation described above).
+
+![Example](auth_chain_diff.dot.png)
+
+Note that we don't include all links between events and their auth events, as
+most of those links would be redundant. For example, all events point to the
+create event, but each chain only needs the one link from it's base to the
+create event.
+
+## Using the Index
+
+This index can be used to calculate the auth chain difference of the state sets
+by looking at the chain ID and sequence numbers reachable from each state set:
+
+1. For every state set lookup the chain ID/sequence numbers of each state event
+2. Use the index to find all chains and the maximum sequence number reachable
+   from each state set.
+3. The auth chain difference is then all events in each chain that have sequence
+   numbers between the maximum sequence number reachable from *any* state set and
+   the minimum reachable by *all* state sets (if any).
+
+Note that steps 2 is effectively calculating the auth chain for each state set
+(in terms of chain IDs and sequence numbers), and step 3 is calculating the
+difference between the union and intersection of the auth chains.
+
+### Worked Example
+
+For example, given the above graph, we can calculate the difference between
+state sets consisting of:
+
+1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
+2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
+
+Using the index we see that the following auth chains are reachable from each
+state set:
+
+1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
+2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
+
+And so, for each the ranges that are in the auth chain difference:
+1. Chain 1: None, (since everything can reach the create event).
+2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
+   sets and the maximum reachable is `2` (corresponding to Bob's second join).
+3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
+   level).
+4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
+
+So the final result is: Bob's second join `(2,2)`, the second power level
+`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.
diff --git a/docs/openid.md b/docs/openid.md
index da391f74aa..ffa4238fff 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -158,6 +158,10 @@ oidc_config:
    client_id: "synapse"
    client_secret: "copy secret generated from above"
    scopes: ["openid", "profile"]
+   user_mapping_provider:
+     config:
+       localpart_template: "{{ user.preferred_username }}"
+       display_name_template: "{{ user.name }}"
 ```
 ### [Auth0][auth0]
 
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index dd981609ac..c8ae46d1b3 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1909,6 +1909,31 @@ sso:
     #
     # Synapse will look for the following templates in this directory:
     #
+    # * HTML page to prompt the user to choose an Identity Provider during
+    #   login: 'sso_login_idp_picker.html'.
+    #
+    #   This is only used if multiple SSO Identity Providers are configured.
+    #
+    #   When rendering, this template is given the following variables:
+    #     * redirect_url: the URL that the user will be redirected to after
+    #       login. Needs manual escaping (see
+    #       https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+    #
+    #     * server_name: the homeserver's name.
+    #
+    #     * providers: a list of available Identity Providers. Each element is
+    #       an object with the following attributes:
+    #         * idp_id: unique identifier for the IdP
+    #         * idp_name: user-facing name for the IdP
+    #
+    #   The rendered HTML page should contain a form which submits its results
+    #   back as a GET request, with the following query parameters:
+    #
+    #     * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+    #       to the template)
+    #
+    #     * idp: the 'idp_id' of the chosen IDP.
+    #
     # * HTML page for a confirmation step before redirecting back to the client
     #   with the login token: 'sso_redirect_confirm.html'.
     #
diff --git a/docs/systemd-with-workers/README.md b/docs/systemd-with-workers/README.md
index 8e57d4f62e..cfa36be7b4 100644
--- a/docs/systemd-with-workers/README.md
+++ b/docs/systemd-with-workers/README.md
@@ -31,7 +31,7 @@ There is no need for a separate configuration file for the master process.
 1. Adjust synapse configuration files as above.
 1. Copy the `*.service` and `*.target` files in [system](system) to
 `/etc/systemd/system`.
-1. Run `systemctl deamon-reload` to tell systemd to load the new unit files.
+1. Run `systemctl daemon-reload` to tell systemd to load the new unit files.
 1. Run `systemctl enable matrix-synapse.service`. This will configure the
 synapse master process to be started as part of the `matrix-synapse.target`
 target.
diff --git a/mypy.ini b/mypy.ini
index 5d15b7bf1c..b996867121 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -103,6 +103,7 @@ files =
   tests/replication,
   tests/test_utils,
   tests/handlers/test_password_providers.py,
+  tests/rest/client/v1/test_login.py,
   tests/rest/client/v2_alpha/test_auth.py,
   tests/util/test_stream_change_cache.py
 
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 5ad17aa90f..22dd169bfb 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -629,6 +629,7 @@ class Porter(object):
             await self._setup_state_group_id_seq()
             await self._setup_user_id_seq()
             await self._setup_events_stream_seqs()
+            await self._setup_device_inbox_seq()
 
             # Step 3. Get tables.
             self.progress.set_state("Fetching tables")
@@ -911,6 +912,32 @@ class Porter(object):
             "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
         )
 
+    async def _setup_device_inbox_seq(self):
+        """Set the device inbox sequence to the correct value.
+        """
+        curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_inbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_federation_outbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        next_id = max(curr_local_id, curr_federation_id) + 1
+
+        def r(txn):
+            txn.execute(
+                "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,)
+            )
+
+        return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r)
+
 
 ##############################################
 # The following is simply UI stuff
diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi
index 3f3af59f26..0368ba4703 100644
--- a/stubs/frozendict.pyi
+++ b/stubs/frozendict.pyi
@@ -15,16 +15,7 @@
 
 # Stub for frozendict.
 
-from typing import (
-    Any,
-    Hashable,
-    Iterable,
-    Iterator,
-    Mapping,
-    overload,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload
 
 _KT = TypeVar("_KT", bound=Hashable)  # Key type.
 _VT = TypeVar("_VT")  # Value type.
diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi
index 68779f968e..7b9fd079d9 100644
--- a/stubs/sortedcontainers/sorteddict.pyi
+++ b/stubs/sortedcontainers/sorteddict.pyi
@@ -7,17 +7,17 @@ from typing import (
     Callable,
     Dict,
     Hashable,
-    Iterator,
-    Iterable,
     ItemsView,
+    Iterable,
+    Iterator,
     KeysView,
     List,
     Mapping,
     Optional,
     Sequence,
+    Tuple,
     Type,
     TypeVar,
-    Tuple,
     Union,
     ValuesView,
     overload,
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 522244bb57..bfac6840e6 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -16,7 +16,7 @@
 """Contains *incomplete* type hints for txredisapi.
 """
 
-from typing import List, Optional, Union, Type
+from typing import List, Optional, Type, Union
 
 class RedisProtocol:
     def publish(self, channel: str, message: bytes): ...
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 48c4d7b0be..67ecbd32ff 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
+from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
@@ -186,8 +187,8 @@ class Auth:
             AuthError if access is denied for the user in the access token
         """
         try:
-            ip_addr = self.hs.get_ip_from_request(request)
-            user_agent = request.get_user_agent("")
+            ip_addr = request.getClientIP()
+            user_agent = get_request_user_agent(request)
 
             access_token = self.get_access_token_from_request(request)
 
@@ -275,7 +276,7 @@ class Auth:
             return None, None
 
         if app_service.ip_range_whitelist:
-            ip_address = IPAddress(self.hs.get_ip_from_request(request))
+            ip_address = IPAddress(request.getClientIP())
             if ip_address not in app_service.ip_range_whitelist:
                 return None, None
 
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index f3ecbf36b6..de2cc15d33 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -51,11 +51,11 @@ class RoomDisposition:
 class RoomVersion:
     """An object which describes the unique attributes of a room version."""
 
-    identifier = attr.ib()  # str; the identifier for this version
-    disposition = attr.ib()  # str; one of the RoomDispositions
-    event_format = attr.ib()  # int; one of the EventFormatVersions
-    state_res = attr.ib()  # int; one of the StateResolutionVersions
-    enforce_key_validity = attr.ib()  # bool
+    identifier = attr.ib(type=str)  # the identifier for this version
+    disposition = attr.ib(type=str)  # one of the RoomDispositions
+    event_format = attr.ib(type=int)  # one of the EventFormatVersions
+    state_res = attr.ib(type=int)  # one of the StateResolutionVersions
+    enforce_key_validity = attr.ib(type=bool)
 
     # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
     special_case_aliases_auth = attr.ib(type=bool)
@@ -64,9 +64,11 @@ class RoomVersion:
     # * Floats
     # * NaN, Infinity, -Infinity
     strict_canonicaljson = attr.ib(type=bool)
-    # bool: MSC2209: Check 'notifications' key while verifying
+    # MSC2209: Check 'notifications' key while verifying
     # m.room.power_levels auth rules.
     limit_notifications_power_levels = attr.ib(type=bool)
+    # MSC2174/MSC2176: Apply updated redaction rules algorithm.
+    msc2176_redaction_rules = attr.ib(type=bool)
 
 
 class RoomVersions:
@@ -79,6 +81,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V2 = RoomVersion(
         "2",
@@ -89,6 +92,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V3 = RoomVersion(
         "3",
@@ -99,6 +103,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V4 = RoomVersion(
         "4",
@@ -109,6 +114,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V5 = RoomVersion(
         "5",
@@ -119,6 +125,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V6 = RoomVersion(
         "6",
@@ -129,6 +136,18 @@ class RoomVersions:
         special_case_aliases_auth=False,
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+    )
+    MSC2176 = RoomVersion(
+        "org.matrix.msc2176",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=True,
     )
 
 
@@ -141,5 +160,6 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V4,
         RoomVersions.V5,
         RoomVersions.V6,
+        RoomVersions.MSC2176,
     )
 }  # type: Dict[str, RoomVersion]
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 37ecdbe3d8..395e202b89 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2017 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,7 +20,7 @@ import signal
 import socket
 import sys
 import traceback
-from typing import Iterable
+from typing import Awaitable, Callable, Iterable
 
 from typing_extensions import NoReturn
 
@@ -143,6 +144,45 @@ def quit_with_error(error_string: str) -> NoReturn:
     sys.exit(1)
 
 
+def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
+    """Register a callback with the reactor, to be called once it is running
+
+    This can be used to initialise parts of the system which require an asynchronous
+    setup.
+
+    Any exception raised by the callback will be printed and logged, and the process
+    will exit.
+    """
+
+    async def wrapper():
+        try:
+            await cb(*args, **kwargs)
+        except Exception:
+            # previously, we used Failure().printTraceback() here, in the hope that
+            # would give better tracebacks than traceback.print_exc(). However, that
+            # doesn't handle chained exceptions (with a __cause__ or __context__) well,
+            # and I *think* the need for Failure() is reduced now that we mostly use
+            # async/await.
+
+            # Write the exception to both the logs *and* the unredirected stderr,
+            # because people tend to get confused if it only goes to one or the other.
+            #
+            # One problem with this is that if people are using a logging config that
+            # logs to the console (as is common eg under docker), they will get two
+            # copies of the exception. We could maybe try to detect that, but it's
+            # probably a cost we can bear.
+            logger.fatal("Error during startup", exc_info=True)
+            print("Error during startup:", file=sys.__stderr__)
+            traceback.print_exc(file=sys.__stderr__)
+
+            # it's no use calling sys.exit here, since that just raises a SystemExit
+            # exception which is then caught by the reactor, and everything carries
+            # on as normal.
+            os._exit(1)
+
+    reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
+
+
 def listen_metrics(bind_addresses, port):
     """
     Start Prometheus metrics server.
@@ -227,7 +267,7 @@ def refresh_certificate(hs):
         logger.info("Context factories updated.")
 
 
-def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
+async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
     """
     Start a Synapse server or worker.
 
@@ -241,75 +281,67 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
         hs: homeserver instance
         listeners: Listener configuration ('listeners' in homeserver.yaml)
     """
-    try:
-        # Set up the SIGHUP machinery.
-        if hasattr(signal, "SIGHUP"):
+    # Set up the SIGHUP machinery.
+    if hasattr(signal, "SIGHUP"):
+        reactor = hs.get_reactor()
 
-            reactor = hs.get_reactor()
+        @wrap_as_background_process("sighup")
+        def handle_sighup(*args, **kwargs):
+            # Tell systemd our state, if we're using it. This will silently fail if
+            # we're not using systemd.
+            sdnotify(b"RELOADING=1")
 
-            @wrap_as_background_process("sighup")
-            def handle_sighup(*args, **kwargs):
-                # Tell systemd our state, if we're using it. This will silently fail if
-                # we're not using systemd.
-                sdnotify(b"RELOADING=1")
+            for i, args, kwargs in _sighup_callbacks:
+                i(*args, **kwargs)
 
-                for i, args, kwargs in _sighup_callbacks:
-                    i(*args, **kwargs)
+            sdnotify(b"READY=1")
 
-                sdnotify(b"READY=1")
+        # We defer running the sighup handlers until next reactor tick. This
+        # is so that we're in a sane state, e.g. flushing the logs may fail
+        # if the sighup happens in the middle of writing a log entry.
+        def run_sighup(*args, **kwargs):
+            # `callFromThread` should be "signal safe" as well as thread
+            # safe.
+            reactor.callFromThread(handle_sighup, *args, **kwargs)
 
-            # We defer running the sighup handlers until next reactor tick. This
-            # is so that we're in a sane state, e.g. flushing the logs may fail
-            # if the sighup happens in the middle of writing a log entry.
-            def run_sighup(*args, **kwargs):
-                # `callFromThread` should be "signal safe" as well as thread
-                # safe.
-                reactor.callFromThread(handle_sighup, *args, **kwargs)
+        signal.signal(signal.SIGHUP, run_sighup)
 
-            signal.signal(signal.SIGHUP, run_sighup)
+        register_sighup(refresh_certificate, hs)
 
-            register_sighup(refresh_certificate, hs)
+    # Load the certificate from disk.
+    refresh_certificate(hs)
 
-        # Load the certificate from disk.
-        refresh_certificate(hs)
+    # Start the tracer
+    synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
+        hs
+    )
 
-        # Start the tracer
-        synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
-            hs
-        )
+    # It is now safe to start your Synapse.
+    hs.start_listening(listeners)
+    hs.get_datastore().db_pool.start_profiling()
+    hs.get_pusherpool().start()
+
+    # Log when we start the shut down process.
+    hs.get_reactor().addSystemEventTrigger(
+        "before", "shutdown", logger.info, "Shutting down..."
+    )
 
-        # It is now safe to start your Synapse.
-        hs.start_listening(listeners)
-        hs.get_datastore().db_pool.start_profiling()
-        hs.get_pusherpool().start()
+    setup_sentry(hs)
+    setup_sdnotify(hs)
 
-        # Log when we start the shut down process.
-        hs.get_reactor().addSystemEventTrigger(
-            "before", "shutdown", logger.info, "Shutting down..."
-        )
+    # If background tasks are running on the main process, start collecting the
+    # phone home stats.
+    if hs.config.run_background_tasks:
+        start_phone_stats_home(hs)
 
-        setup_sentry(hs)
-        setup_sdnotify(hs)
-
-        # If background tasks are running on the main process, start collecting the
-        # phone home stats.
-        if hs.config.run_background_tasks:
-            start_phone_stats_home(hs)
-
-        # We now freeze all allocated objects in the hopes that (almost)
-        # everything currently allocated are things that will be used for the
-        # rest of time. Doing so means less work each GC (hopefully).
-        #
-        # This only works on Python 3.7
-        if sys.version_info >= (3, 7):
-            gc.collect()
-            gc.freeze()
-    except Exception:
-        traceback.print_exc(file=sys.stderr)
-        reactor = hs.get_reactor()
-        if reactor.running:
-            reactor.stop()
-        sys.exit(1)
+    # We now freeze all allocated objects in the hopes that (almost)
+    # everything currently allocated are things that will be used for the
+    # rest of time. Doing so means less work each GC (hopefully).
+    #
+    # This only works on Python 3.7
+    if sys.version_info >= (3, 7):
+        gc.collect()
+        gc.freeze()
 
 
 def setup_sentry(hs):
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index fa23d9bb20..f24c648ac7 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
 
 from typing_extensions import ContextManager
 
-from twisted.internet import address, reactor
+from twisted.internet import address
 
 import synapse
 import synapse.events
@@ -34,6 +34,7 @@ from synapse.api.urls import (
     SERVER_KEY_V2_PREFIX,
 )
 from synapse.app import _base
+from synapse.app._base import register_start
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
@@ -99,21 +100,27 @@ from synapse.rest.client.v1.profile import (
 )
 from synapse.rest.client.v1.push_rule import PushRuleRestServlet
 from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, sync, user_directory
+from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
 from synapse.rest.client.v2_alpha.account_data import (
     AccountDataServlet,
     RoomAccountDataServlet,
 )
-from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
+from synapse.rest.client.v2_alpha.keys import (
+    KeyChangesServlet,
+    KeyQueryServlet,
+    OneTimeKeyServlet,
+)
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.server import HomeServer, cache_in_self
 from synapse.storage.databases.main.censor_events import CensorEventsStore
 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
+from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
 from synapse.storage.databases.main.media_repository import MediaRepositoryStore
 from synapse.storage.databases.main.metrics import ServerMetricsStore
 from synapse.storage.databases.main.monthly_active_users import (
@@ -445,6 +452,7 @@ class GenericWorkerSlavedStore(
     UserDirectoryStore,
     StatsStore,
     UIAuthWorkerStore,
+    EndToEndRoomKeyStore,
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedReceiptsStore,
@@ -502,6 +510,7 @@ class GenericWorkerServer(HomeServer):
                     LoginRestServlet(self).register(resource)
                     ThreepidRestServlet(self).register(resource)
                     KeyQueryServlet(self).register(resource)
+                    OneTimeKeyServlet(self).register(resource)
                     KeyChangesServlet(self).register(resource)
                     VoipRestServlet(self).register(resource)
                     PushRuleRestServlet(self).register(resource)
@@ -519,6 +528,9 @@ class GenericWorkerServer(HomeServer):
                     room.register_servlets(self, resource, True)
                     room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)
+                    room_keys.register_servlets(self, resource)
+
+                    SendToDeviceRestServlet(self).register(resource)
 
                     user_directory.register_servlets(self, resource)
 
@@ -957,9 +969,7 @@ def start(config_options):
     # streams. Will no-op if no streams can be written to by this worker.
     hs.get_replication_streamer()
 
-    reactor.addSystemEventTrigger(
-        "before", "startup", _base.start, hs, config.worker_listeners
-    )
+    register_start(_base.start, hs, config.worker_listeners)
 
     _base.start_worker_reactor("synapse-generic-worker", config)
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8d9b53be53..cbecf23be6 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -15,15 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import logging
 import os
 import sys
 from typing import Iterable, Iterator
 
-from twisted.application import service
-from twisted.internet import defer, reactor
-from twisted.python.failure import Failure
+from twisted.internet import reactor
 from twisted.web.resource import EncodingResourceWrapper, IResource
 from twisted.web.server import GzipEncoderFactory
 from twisted.web.static import File
@@ -40,7 +37,7 @@ from synapse.api.urls import (
     WEB_CLIENT_PREFIX,
 )
 from synapse.app import _base
-from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
+from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
 from synapse.config._base import ConfigError
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
@@ -63,6 +60,7 @@ from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
 from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
@@ -72,7 +70,6 @@ from synapse.storage.prepare_database import UpgradeDatabaseException
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.module_loader import load_module
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse.app.homeserver")
@@ -194,6 +191,7 @@ class SynapseHomeServer(HomeServer):
                     "/.well-known/matrix/client": WellKnownResource(self),
                     "/_synapse/admin": AdminRestResource(self),
                     "/_synapse/client/pick_username": pick_username_resource(self),
+                    "/_synapse/client/pick_idp": PickIdpResource(self),
                 }
             )
 
@@ -415,40 +413,29 @@ def setup(config_options):
             _base.refresh_certificate(hs)
 
     async def start():
-        try:
-            # Run the ACME provisioning code, if it's enabled.
-            if hs.config.acme_enabled:
-                acme = hs.get_acme_handler()
-                # Start up the webservices which we will respond to ACME
-                # challenges with, and then provision.
-                await acme.start_listening()
-                await do_acme()
+        # Run the ACME provisioning code, if it's enabled.
+        if hs.config.acme_enabled:
+            acme = hs.get_acme_handler()
+            # Start up the webservices which we will respond to ACME
+            # challenges with, and then provision.
+            await acme.start_listening()
+            await do_acme()
 
-                # Check if it needs to be reprovisioned every day.
-                hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
+            # Check if it needs to be reprovisioned every day.
+            hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
 
-            # Load the OIDC provider metadatas, if OIDC is enabled.
-            if hs.config.oidc_enabled:
-                oidc = hs.get_oidc_handler()
-                # Loading the provider metadata also ensures the provider config is valid.
-                await oidc.load_metadata()
-                await oidc.load_jwks()
+        # Load the OIDC provider metadatas, if OIDC is enabled.
+        if hs.config.oidc_enabled:
+            oidc = hs.get_oidc_handler()
+            # Loading the provider metadata also ensures the provider config is valid.
+            await oidc.load_metadata()
+            await oidc.load_jwks()
 
-            _base.start(hs, config.listeners)
+        await _base.start(hs, config.listeners)
 
-            hs.get_datastore().db_pool.updates.start_doing_background_updates()
-        except Exception:
-            # Print the exception and bail out.
-            print("Error during startup:", file=sys.stderr)
+        hs.get_datastore().db_pool.updates.start_doing_background_updates()
 
-            # this gives better tracebacks than traceback.print_exc()
-            Failure().printTraceback(file=sys.stderr)
-
-            if reactor.running:
-                reactor.stop()
-            sys.exit(1)
-
-    reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
+    register_start(start)
 
     return hs
 
@@ -485,25 +472,6 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
         e = e.__cause__
 
 
-class SynapseService(service.Service):
-    """
-    A twisted Service class that will start synapse. Used to run synapse
-    via twistd and a .tac.
-    """
-
-    def __init__(self, config):
-        self.config = config
-
-    def startService(self):
-        hs = setup(self.config)
-        change_resource_limit(hs.config.soft_file_limit)
-        if hs.config.gc_thresholds:
-            gc.set_threshold(*hs.config.gc_thresholds)
-
-    def stopService(self):
-        return self._port.stopListening()
-
-
 def run(hs):
     PROFILE_SYNAPSE = False
     if PROFILE_SYNAPSE:
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index 1bbe83c317..8fce7f6bb1 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -56,7 +56,7 @@ def json_error_to_config_error(
     """
     # copy `config_path` before modifying it.
     path = list(config_path)
-    for p in list(e.path):
+    for p in list(e.absolute_path):
         if isinstance(p, int):
             path.append("<item %i>" % p)
         else:
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 93bbd40937..1aeb1c5c92 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -31,6 +31,7 @@ class SSOConfig(Config):
 
         # Read templates from disk
         (
+            self.sso_login_idp_picker_template,
             self.sso_redirect_confirm_template,
             self.sso_auth_confirm_template,
             self.sso_error_template,
@@ -38,6 +39,7 @@ class SSOConfig(Config):
             sso_auth_success_template,
         ) = self.read_templates(
             [
+                "sso_login_idp_picker.html",
                 "sso_redirect_confirm.html",
                 "sso_auth_confirm.html",
                 "sso_error.html",
@@ -98,6 +100,31 @@ class SSOConfig(Config):
             #
             # Synapse will look for the following templates in this directory:
             #
+            # * HTML page to prompt the user to choose an Identity Provider during
+            #   login: 'sso_login_idp_picker.html'.
+            #
+            #   This is only used if multiple SSO Identity Providers are configured.
+            #
+            #   When rendering, this template is given the following variables:
+            #     * redirect_url: the URL that the user will be redirected to after
+            #       login. Needs manual escaping (see
+            #       https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #
+            #     * server_name: the homeserver's name.
+            #
+            #     * providers: a list of available Identity Providers. Each element is
+            #       an object with the following attributes:
+            #         * idp_id: unique identifier for the IdP
+            #         * idp_name: user-facing name for the IdP
+            #
+            #   The rendered HTML page should contain a form which submits its results
+            #   back as a GET request, with the following query parameters:
+            #
+            #     * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+            #       to the template)
+            #
+            #     * idp: the 'idp_id' of the chosen IDP.
+            #
             # * HTML page for a confirmation step before redirecting back to the client
             #   with the login token: 'sso_redirect_confirm.html'.
             #
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 7ca9efec52..364583f48b 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -53,6 +53,9 @@ class WriterLocations:
         default=["master"], type=List[str], converter=_instance_to_list_converter
     )
     typing = attr.ib(default="master", type=str)
+    to_device = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
 
 
 class WorkerConfig(Config):
@@ -124,7 +127,7 @@ class WorkerConfig(Config):
 
         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing"):
+        for stream in ("events", "typing", "to_device"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -133,6 +136,11 @@ class WorkerConfig(Config):
                         % (instance, stream)
                     )
 
+        if len(self.writers.to_device) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `to_device` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
         # Whether this worker should run background tasks or not.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 14f7f1156f..9c22e33813 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -79,13 +79,15 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
         "state_key",
         "depth",
         "prev_events",
-        "prev_state",
         "auth_events",
         "origin",
         "origin_server_ts",
-        "membership",
     ]
 
+    # Room versions from before MSC2176 had additional allowed keys.
+    if not room_version.msc2176_redaction_rules:
+        allowed_keys.extend(["prev_state", "membership"])
+
     event_type = event_dict["type"]
 
     new_content = {}
@@ -98,6 +100,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     if event_type == EventTypes.Member:
         add_fields("membership")
     elif event_type == EventTypes.Create:
+        # MSC2176 rules state that create events cannot be redacted.
+        if room_version.msc2176_redaction_rules:
+            return event_dict
+
         add_fields("creator")
     elif event_type == EventTypes.JoinRules:
         add_fields("join_rule")
@@ -112,10 +118,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
             "kick",
             "redact",
         )
+
+        if room_version.msc2176_redaction_rules:
+            add_fields("invite")
+
     elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
         add_fields("aliases")
     elif event_type == EventTypes.RoomHistoryVisibility:
         add_fields("history_visibility")
+    elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
+        add_fields("redacts")
 
     allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 35e345ce70..e5339aca23 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -860,8 +861,10 @@ class FederationHandlerRegistry:
         )  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
         self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]
 
-        # Map from type to instance name that we should route EDU handling to.
-        self._edu_type_to_instance = {}  # type: Dict[str, str]
+        # Map from type to instance names that we should route EDU handling to.
+        # We randomly choose one instance from the list to route to for each new
+        # EDU received.
+        self._edu_type_to_instance = {}  # type: Dict[str, List[str]]
 
     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
@@ -905,7 +908,12 @@ class FederationHandlerRegistry:
     def register_instance_for_edu(self, edu_type: str, instance_name: str):
         """Register that the EDU handler is on a different instance than master.
         """
-        self._edu_type_to_instance[edu_type] = instance_name
+        self._edu_type_to_instance[edu_type] = [instance_name]
+
+    def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
+        """Register that the EDU handler is on multiple instances.
+        """
+        self._edu_type_to_instance[edu_type] = instance_names
 
     async def on_edu(self, edu_type: str, origin: str, content: dict):
         if not self.config.use_presence and edu_type == "m.presence":
@@ -924,8 +932,11 @@ class FederationHandlerRegistry:
             return
 
         # Check if we can route it somewhere else that isn't us
-        route_to = self._edu_type_to_instance.get(edu_type, "master")
-        if route_to != self._instance_name:
+        instances = self._edu_type_to_instance.get(edu_type, ["master"])
+        if self._instance_name not in instances:
+            # Pick an instance randomly so that we don't overload one.
+            route_to = random.choice(instances)
+
             try:
                 await self._send_edu(
                     instance_name=route_to,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f4434673dc..4f881a439a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -49,8 +49,13 @@ from synapse.api.errors import (
     UserDeactivatedError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.ui_auth import (
+    INTERACTIVE_AUTH_CHECKERS,
+    UIAuthSessionDataConstants,
+)
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http import get_request_user_agent
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
@@ -62,8 +67,6 @@ from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
-from ._base import BaseHandler
-
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
@@ -284,7 +287,6 @@ class AuthHandler(BaseHandler):
         requester: Requester,
         request: SynapseRequest,
         request_body: Dict[str, Any],
-        clientip: str,
         description: str,
     ) -> Tuple[dict, Optional[str]]:
         """
@@ -301,8 +303,6 @@ class AuthHandler(BaseHandler):
 
             request_body: The body of the request sent by the client
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
@@ -338,10 +338,10 @@ class AuthHandler(BaseHandler):
                 request_body.pop("auth", None)
                 return request_body, None
 
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         # Check if we should be ratelimited due to too many previous failed attempts
-        self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+        self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
 
         # build a list of supported flows
         supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -349,13 +349,16 @@ class AuthHandler(BaseHandler):
         )
         flows = [[login_type] for login_type in supported_ui_auth_types]
 
+        def get_new_session_data() -> JsonDict:
+            return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
+
         try:
             result, params, session_id = await self.check_ui_auth(
-                flows, request, request_body, clientip, description
+                flows, request, request_body, description, get_new_session_data,
             )
         except LoginError:
             # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
-            self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+            self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
             raise
 
         # find the completed login type
@@ -363,14 +366,14 @@ class AuthHandler(BaseHandler):
             if login_type not in result:
                 continue
 
-            user_id = result[login_type]
+            validated_user_id = result[login_type]
             break
         else:
             # this can't happen
             raise Exception("check_auth returned True but no successful login type")
 
         # check that the UI auth matched the access token
-        if user_id != requester.user.to_string():
+        if validated_user_id != requester_user_id:
             raise AuthError(403, "Invalid auth")
 
         # Note that the access token has been validated.
@@ -402,13 +405,9 @@ class AuthHandler(BaseHandler):
 
         # if sso is enabled, allow the user to log in via SSO iff they have a mapping
         # from sso to mxid.
-        if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
-            if await self.store.get_external_ids_by_user(user.to_string()):
-                ui_auth_types.add(LoginType.SSO)
-
-        # Our CAS impl does not (yet) correctly register users in user_external_ids,
-        # so always offer that if it's available.
-        if self.hs.config.cas.cas_enabled:
+        if await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user.to_string()
+        ):
             ui_auth_types.add(LoginType.SSO)
 
         return ui_auth_types
@@ -426,8 +425,8 @@ class AuthHandler(BaseHandler):
         flows: List[List[str]],
         request: SynapseRequest,
         clientdict: Dict[str, Any],
-        clientip: str,
         description: str,
+        get_new_session_data: Optional[Callable[[], JsonDict]] = None,
     ) -> Tuple[dict, dict, str]:
         """
         Takes a dictionary sent by the client in the login / registration
@@ -448,11 +447,16 @@ class AuthHandler(BaseHandler):
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
+            get_new_session_data:
+                an optional callback which will be called when starting a new session.
+                it should return data to be stored as part of the session.
+
+                The keys of the returned data should be entries in
+                UIAuthSessionDataConstants.
+
         Returns:
             A tuple of (creds, params, session_id).
 
@@ -480,10 +484,15 @@ class AuthHandler(BaseHandler):
 
         # If there's no session ID, create a new session.
         if not sid:
+            new_session_data = get_new_session_data() if get_new_session_data else {}
+
             session = await self.store.create_ui_auth_session(
                 clientdict, uri, method, description
             )
 
+            for k, v in new_session_data.items():
+                await self.set_session_data(session.session_id, k, v)
+
         else:
             try:
                 session = await self.store.get_ui_auth_session(sid)
@@ -539,7 +548,8 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
-        user_agent = request.get_user_agent("")
+        user_agent = get_request_user_agent(request)
+        clientip = request.getClientIP()
 
         await self.store.add_user_agent_ip_to_ui_auth_session(
             session.session_id, user_agent, clientip
@@ -644,7 +654,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key to store the data under. An entry from
+                UIAuthSessionDataConstants.
             value: The data to store
         """
         try:
@@ -660,7 +671,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key the data was stored under. An entry from
+                UIAuthSessionDataConstants.
             default: Value to return if the key has not been set
         """
         try:
@@ -1334,12 +1346,12 @@ class AuthHandler(BaseHandler):
         else:
             return False
 
-    async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+    async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
         """
         Get the HTML for the SSO redirect confirmation page.
 
         Args:
-            redirect_url: The URL to redirect to the SSO provider.
+            request: The incoming HTTP request
             session_id: The user interactive authentication session ID.
 
         Returns:
@@ -1349,6 +1361,35 @@ class AuthHandler(BaseHandler):
             session = await self.store.get_ui_auth_session(session_id)
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
+
+        user_id_to_verify = await self.get_session_data(
+            session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
+        idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user_id_to_verify
+        )
+
+        if not idps:
+            # we checked that the user had some remote identities before offering an SSO
+            # flow, so either it's been deleted or the client has requested SSO despite
+            # it not being offered.
+            raise SynapseError(400, "User has no SSO identities")
+
+        # for now, just pick one
+        idp_id, sso_auth_provider = next(iter(idps.items()))
+        if len(idps) > 0:
+            logger.warning(
+                "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+                "picking %r",
+                user_id_to_verify,
+                idp_id,
+            )
+
+        redirect_url = await sso_auth_provider.handle_redirect_request(
+            request, None, session_id
+        )
+
         return self._sso_auth_confirm_template.render(
             description=session.description, redirect_url=redirect_url,
         )
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index fca210a5a6..f3430c6713 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -75,10 +75,15 @@ class CasHandler:
         self._http_client = hs.get_proxied_http_client()
 
         # identifier for the external_ids table
-        self._auth_provider_id = "cas"
+        self.idp_id = "cas"
+
+        # user-facing name of this auth provider
+        self.idp_name = "CAS"
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _build_service_param(self, args: Dict[str, str]) -> str:
         """
         Generates a value to use as the "service" parameter when redirecting or
@@ -105,7 +110,7 @@ class CasHandler:
         Args:
             ticket: The CAS ticket from the client.
             service_args: Additional arguments to include in the service URL.
-                Should be the same as those passed to `get_redirect_url`.
+                Should be the same as those passed to `handle_redirect_request`.
 
         Raises:
             CasError: If there's an error parsing the CAS response.
@@ -184,16 +189,31 @@ class CasHandler:
 
         return CasResponse(user, attributes)
 
-    def get_redirect_url(self, service_args: Dict[str, str]) -> str:
-        """
-        Generates a URL for the CAS server where the client should be redirected.
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Generates a URL for the CAS server where the client should be redirected.
 
         Args:
-            service_args: Additional arguments to include in the final redirect URL.
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
 
         Returns:
-            The URL to redirect the client to.
+            URL to redirect to
         """
+
+        if ui_auth_session_id:
+            service_args = {"session": ui_auth_session_id}
+        else:
+            assert client_redirect_url
+            service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
+
         args = urllib.parse.urlencode(
             {"service": self._build_service_param(service_args)}
         )
@@ -275,7 +295,7 @@ class CasHandler:
         # first check if we're doing a UIA
         if session:
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, cas_response.username, session, request,
+                self.idp_id, cas_response.username, session, request,
             )
 
         # otherwise, we're handling a login request.
@@ -375,7 +395,7 @@ class CasHandler:
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             cas_response.username,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e808142365..c4a3b26a84 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
 
 from ._base import BaseHandler
 
@@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
         self._device_handler = hs.get_device_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._identity_handler = hs.get_identity_handler()
+        self._profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
         self._server_name = hs.hostname
 
@@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
         self._account_validity_enabled = hs.config.account_validity.enabled
 
     async def deactivate_account(
-        self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+        self,
+        user_id: str,
+        erase_data: bool,
+        requester: Requester,
+        id_server: Optional[str] = None,
+        by_admin: bool = False,
     ) -> bool:
         """Deactivate a user's account
 
         Args:
             user_id: ID of user to be deactivated
             erase_data: whether to GDPR-erase the user's data
+            requester: The user attempting to make this change.
             id_server: Use the given identity server when unbinding
                 any threepids. If None then will attempt to unbind using the
                 identity server specified when binding (if known).
+            by_admin: Whether this change was made by an administrator.
 
         Returns:
             True if identity server supports removing threepids, otherwise False.
@@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
 
         # Mark the user as erased, if they asked for that
         if erase_data:
+            user = UserID.from_string(user_id)
+            # Remove avatar URL from this user
+            await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+            # Remove displayname from this user
+            await self._profile_handler.set_displayname(user, requester, "", by_admin)
+
             logger.info("Marking %s as erased", user_id)
             await self.store.mark_user_erased(user_id)
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 9cac5a8463..fc974a82e8 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -24,6 +24,7 @@ from synapse.logging.opentracing import (
     set_tag,
     start_active_span,
 )
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
@@ -44,13 +45,37 @@ class DeviceMessageHandler:
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
-        self.federation = hs.get_federation_sender()
 
-        hs.get_federation_registry().register_edu_handler(
-            "m.direct_to_device", self.on_direct_to_device_edu
-        )
+        # We only need to poke the federation sender explicitly if its on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the to-device replication stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the to device EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.to_device:
+            hs.get_federation_registry().register_edu_handler(
+                "m.direct_to_device", self.on_direct_to_device_edu
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.direct_to_device", hs.config.worker.writers.to_device,
+            )
 
-        self._device_list_updater = hs.get_device_handler().device_list_updater
+        # The handler to call when we think a user's device list might be out of
+        # sync. We do all device list resyncing on the master instance, so if
+        # we're on a worker we hit the device resync replication API.
+        if hs.config.worker.worker_app is None:
+            self._user_device_resync = (
+                hs.get_device_handler().device_list_updater.user_device_resync
+            )
+        else:
+            self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
+                hs
+            )
 
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         local_messages = {}
@@ -138,9 +163,7 @@ class DeviceMessageHandler:
             await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
-            run_in_background(
-                self._device_list_updater.user_device_resync, sender_user_id
-            )
+            run_in_background(self._user_device_resync, sender_user_id)
 
     async def send_device_message(
         self,
@@ -195,7 +218,8 @@ class DeviceMessageHandler:
         )
 
         log_kv({"remote_messages": remote_messages})
-        for destination in remote_messages.keys():
-            # Enqueue a new federation transaction to send the new
-            # device messages to each remote destination.
-            self.federation.send_device_messages(destination)
+        if self.federation_sender:
+            for destination in remote_messages.keys():
+                # Enqueue a new federation transaction to send the new
+                # device messages to each remote destination.
+                self.federation_sender.send_device_messages(destination)
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 709f8dfc13..88097639ef 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import inspect
 import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
 from urllib.parse import urlencode
 
 import attr
@@ -35,7 +35,6 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -85,12 +84,15 @@ class OidcError(Exception):
         return self.error
 
 
-class OidcHandler(BaseHandler):
+class OidcHandler:
     """Handles requests related to the OpenID Connect login flow.
     """
 
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        self._store = hs.get_datastore()
+
+        self._token_generator = OidcSessionTokenGenerator(hs)
+
         self._callback_url = hs.config.oidc_callback_url  # type: str
         self._scopes = hs.config.oidc_scopes  # type: List[str]
         self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
@@ -116,13 +118,17 @@ class OidcHandler(BaseHandler):
 
         self._http_client = hs.get_proxied_http_client()
         self._server_name = hs.config.server_name  # type: str
-        self._macaroon_secret_key = hs.config.macaroon_secret_key
 
         # identifier for the external_ids table
-        self._auth_provider_id = "oidc"
+        self.idp_id = "oidc"
+
+        # user-facing name of this auth provider
+        self.idp_name = "OIDC"
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _validate_metadata(self):
         """Verifies the provider metadata.
 
@@ -475,7 +481,7 @@ class OidcHandler(BaseHandler):
     async def handle_redirect_request(
         self,
         request: SynapseRequest,
-        client_redirect_url: bytes,
+        client_redirect_url: Optional[bytes],
         ui_auth_session_id: Optional[str] = None,
     ) -> str:
         """Handle an incoming request to /login/sso/redirect
@@ -499,7 +505,7 @@ class OidcHandler(BaseHandler):
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
-                when everything is done
+                when everything is done (or None for UI Auth)
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
@@ -511,11 +517,16 @@ class OidcHandler(BaseHandler):
         state = generate_token()
         nonce = generate_token()
 
-        cookie = self._generate_oidc_session_token(
+        if not client_redirect_url:
+            client_redirect_url = b""
+
+        cookie = self._token_generator.generate_oidc_session_token(
             state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url.decode(),
-            ui_auth_session_id=ui_auth_session_id,
+            session_data=OidcSessionData(
+                nonce=nonce,
+                client_redirect_url=client_redirect_url.decode(),
+                ui_auth_session_id=ui_auth_session_id,
+            ),
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
@@ -620,11 +631,9 @@ class OidcHandler(BaseHandler):
 
         # Deserialize the session token and verify it.
         try:
-            (
-                nonce,
-                client_redirect_url,
-                ui_auth_session_id,
-            ) = self._verify_oidc_session_token(session, state)
+            session_data = self._token_generator.verify_oidc_session_token(
+                session, state
+            )
         except MacaroonDeserializationException as e:
             logger.exception("Invalid session")
             self._sso_handler.render_error(request, "invalid_session", str(e))
@@ -666,14 +675,14 @@ class OidcHandler(BaseHandler):
         else:
             logger.debug("Extracting userinfo from id_token")
             try:
-                userinfo = await self._parse_id_token(token, nonce=nonce)
+                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
             except Exception as e:
                 logger.exception("Invalid id_token")
                 self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
         # first check if we're doing a UIA
-        if ui_auth_session_id:
+        if session_data.ui_auth_session_id:
             try:
                 remote_user_id = self._remote_id_from_userinfo(userinfo)
             except Exception as e:
@@ -682,7 +691,7 @@ class OidcHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+                self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
             )
 
         # otherwise, it's a login
@@ -690,133 +699,12 @@ class OidcHandler(BaseHandler):
         # Call the mapper to register/login the user
         try:
             await self._complete_oidc_login(
-                userinfo, token, request, client_redirect_url
+                userinfo, token, request, session_data.client_redirect_url
             )
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
 
-    def _generate_oidc_session_token(
-        self,
-        state: str,
-        nonce: str,
-        client_redirect_url: str,
-        ui_auth_session_id: Optional[str],
-        duration_in_ms: int = (60 * 60 * 1000),
-    ) -> str:
-        """Generates a signed token storing data about an OIDC session.
-
-        When Synapse initiates an authorization flow, it creates a random state
-        and a random nonce. Those parameters are given to the provider and
-        should be verified when the client comes back from the provider.
-        It is also used to store the client_redirect_url, which is used to
-        complete the SSO login flow.
-
-        Args:
-            state: The ``state`` parameter passed to the OIDC provider.
-            nonce: The ``nonce`` parameter passed to the OIDC provider.
-            client_redirect_url: The URL the client gave when it initiated the
-                flow.
-            ui_auth_session_id: The session ID of the ongoing UI Auth (or
-                None if this is a login).
-            duration_in_ms: An optional duration for the token in milliseconds.
-                Defaults to an hour.
-
-        Returns:
-            A signed macaroon token with the session information.
-        """
-        macaroon = pymacaroons.Macaroon(
-            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
-        )
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("type = session")
-        macaroon.add_first_party_caveat("state = %s" % (state,))
-        macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
-        macaroon.add_first_party_caveat(
-            "client_redirect_url = %s" % (client_redirect_url,)
-        )
-        if ui_auth_session_id:
-            macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (ui_auth_session_id,)
-            )
-        now = self.clock.time_msec()
-        expiry = now + duration_in_ms
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-
-        return macaroon.serialize()
-
-    def _verify_oidc_session_token(
-        self, session: bytes, state: str
-    ) -> Tuple[str, str, Optional[str]]:
-        """Verifies and extract an OIDC session token.
-
-        This verifies that a given session token was issued by this homeserver
-        and extract the nonce and client_redirect_url caveats.
-
-        Args:
-            session: The session token to verify
-            state: The state the OIDC provider gave back
-
-        Returns:
-            The nonce, client_redirect_url, and ui_auth_session_id for this session
-        """
-        macaroon = pymacaroons.Macaroon.deserialize(session)
-
-        v = pymacaroons.Verifier()
-        v.satisfy_exact("gen = 1")
-        v.satisfy_exact("type = session")
-        v.satisfy_exact("state = %s" % (state,))
-        v.satisfy_general(lambda c: c.startswith("nonce = "))
-        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
-        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
-        # to always satisfy this.
-        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
-        v.satisfy_general(self._verify_expiry)
-
-        v.verify(macaroon, self._macaroon_secret_key)
-
-        # Extract the `nonce`, `client_redirect_url`, and maybe the
-        # `ui_auth_session_id` from the token.
-        nonce = self._get_value_from_macaroon(macaroon, "nonce")
-        client_redirect_url = self._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
-        try:
-            ui_auth_session_id = self._get_value_from_macaroon(
-                macaroon, "ui_auth_session_id"
-            )  # type: Optional[str]
-        except ValueError:
-            ui_auth_session_id = None
-
-        return nonce, client_redirect_url, ui_auth_session_id
-
-    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
-        """Extracts a caveat value from a macaroon token.
-
-        Args:
-            macaroon: the token
-            key: the key of the caveat to extract
-
-        Returns:
-            The extracted value
-
-        Raises:
-            Exception: if the caveat was not in the macaroon
-        """
-        prefix = key + " = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(prefix):
-                return caveat.caveat_id[len(prefix) :]
-        raise ValueError("No %s caveat in macaroon" % (key,))
-
-    def _verify_expiry(self, caveat: str) -> bool:
-        prefix = "time < "
-        if not caveat.startswith(prefix):
-            return False
-        expiry = int(caveat[len(prefix) :])
-        now = self.clock.time_msec()
-        return now < expiry
-
     async def _complete_oidc_login(
         self,
         userinfo: UserInfo,
@@ -893,8 +781,8 @@ class OidcHandler(BaseHandler):
                 # and attempt to match it.
                 attributes = await oidc_response_to_user_attributes(failures=0)
 
-                user_id = UserID(attributes.localpart, self.server_name).to_string()
-                users = await self.store.get_users_by_id_case_insensitive(user_id)
+                user_id = UserID(attributes.localpart, self._server_name).to_string()
+                users = await self._store.get_users_by_id_case_insensitive(user_id)
                 if users:
                     # If an existing matrix ID is returned, then use it.
                     if len(users) == 1:
@@ -923,7 +811,7 @@ class OidcHandler(BaseHandler):
             extra_attributes = await get_extra_attributes(userinfo, token)
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
@@ -946,6 +834,148 @@ class OidcHandler(BaseHandler):
         return str(remote_user_id)
 
 
+class OidcSessionTokenGenerator:
+    """Methods for generating and checking OIDC Session cookies."""
+
+    def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
+        self._server_name = hs.hostname
+        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
+
+    def generate_oidc_session_token(
+        self,
+        state: str,
+        session_data: "OidcSessionData",
+        duration_in_ms: int = (60 * 60 * 1000),
+    ) -> str:
+        """Generates a signed token storing data about an OIDC session.
+
+        When Synapse initiates an authorization flow, it creates a random state
+        and a random nonce. Those parameters are given to the provider and
+        should be verified when the client comes back from the provider.
+        It is also used to store the client_redirect_url, which is used to
+        complete the SSO login flow.
+
+        Args:
+            state: The ``state`` parameter passed to the OIDC provider.
+            session_data: data to include in the session token.
+            duration_in_ms: An optional duration for the token in milliseconds.
+                Defaults to an hour.
+
+        Returns:
+            A signed macaroon token with the session information.
+        """
+        macaroon = pymacaroons.Macaroon(
+            location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+        )
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("type = session")
+        macaroon.add_first_party_caveat("state = %s" % (state,))
+        macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
+        macaroon.add_first_party_caveat(
+            "client_redirect_url = %s" % (session_data.client_redirect_url,)
+        )
+        if session_data.ui_auth_session_id:
+            macaroon.add_first_party_caveat(
+                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+            )
+        now = self._clock.time_msec()
+        expiry = now + duration_in_ms
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+
+        return macaroon.serialize()
+
+    def verify_oidc_session_token(
+        self, session: bytes, state: str
+    ) -> "OidcSessionData":
+        """Verifies and extract an OIDC session token.
+
+        This verifies that a given session token was issued by this homeserver
+        and extract the nonce and client_redirect_url caveats.
+
+        Args:
+            session: The session token to verify
+            state: The state the OIDC provider gave back
+
+        Returns:
+            The data extracted from the session cookie
+        """
+        macaroon = pymacaroons.Macaroon.deserialize(session)
+
+        v = pymacaroons.Verifier()
+        v.satisfy_exact("gen = 1")
+        v.satisfy_exact("type = session")
+        v.satisfy_exact("state = %s" % (state,))
+        v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
+        # to always satisfy this.
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
+        v.satisfy_general(self._verify_expiry)
+
+        v.verify(macaroon, self._macaroon_secret_key)
+
+        # Extract the `nonce`, `client_redirect_url`, and maybe the
+        # `ui_auth_session_id` from the token.
+        nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        client_redirect_url = self._get_value_from_macaroon(
+            macaroon, "client_redirect_url"
+        )
+        try:
+            ui_auth_session_id = self._get_value_from_macaroon(
+                macaroon, "ui_auth_session_id"
+            )  # type: Optional[str]
+        except ValueError:
+            ui_auth_session_id = None
+
+        return OidcSessionData(
+            nonce=nonce,
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=ui_auth_session_id,
+        )
+
+    def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+        """Extracts a caveat value from a macaroon token.
+
+        Args:
+            macaroon: the token
+            key: the key of the caveat to extract
+
+        Returns:
+            The extracted value
+
+        Raises:
+            Exception: if the caveat was not in the macaroon
+        """
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+    def _verify_expiry(self, caveat: str) -> bool:
+        prefix = "time < "
+        if not caveat.startswith(prefix):
+            return False
+        expiry = int(caveat[len(prefix) :])
+        now = self._clock.time_msec()
+        return now < expiry
+
+
+@attr.s(frozen=True, slots=True)
+class OidcSessionData:
+    """The attributes which are stored in a OIDC session cookie"""
+
+    # The `nonce` parameter passed to the OIDC provider.
+    nonce = attr.ib(type=str)
+
+    # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
+    client_redirect_url = attr.ib(type=str)
+
+    # The session ID of the ongoing UI Auth (None if this is a login)
+    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+
+
 UserAttributeDict = TypedDict(
     "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
 )
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dee0ef45e7..c02b951031 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["displayname"]
+            return result.get("displayname")
 
     async def set_displayname(
         self,
@@ -246,7 +246,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["avatar_url"]
+            return result.get("avatar_url")
 
     async def set_avatar_url(
         self,
@@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
             )
 
+        avatar_url_to_set = new_avatar_url  # type: Optional[str]
+        if new_avatar_url == "":
+            avatar_url_to_set = None
+
         # Same like set_displayname
         if by_admin:
             requester = create_requester(
                 target_user, authenticated_entity=requester.authenticated_entity
             )
 
-        await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+        await self.store.set_profile_avatar_url(
+            target_user.localpart, avatar_url_to_set
+        )
 
         if self.hs.config.user_directory_search_all_users:
             profile = await self.store.get_profileinfo(target_user.localpart)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 1f809fa161..3bece6d668 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -365,7 +365,7 @@ class RoomCreationHandler(BaseHandler):
         creation_content = {
             "room_version": new_room_version.identifier,
             "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
-        }
+        }  # type: JsonDict
 
         # Check if old room was non-federatable
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 5fa7ab3f8b..a8376543c9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -73,27 +73,41 @@ class SamlHandler(BaseHandler):
         )
 
         # identifier for the external_ids table
-        self._auth_provider_id = "saml"
+        self.idp_id = "saml"
+
+        # user-facing name of this auth provider
+        self.idp_name = "SAML"
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         self._sso_handler = hs.get_sso_handler()
+        self._sso_handler.register_identity_provider(self)
 
-    def handle_redirect_request(
-        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
-    ) -> bytes:
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
+            request: the incoming HTTP request
             client_redirect_url: the URL that we should redirect the
-                client to when everything is done
+                client to after login (or None for UI Auth).
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
             URL to redirect to
         """
+        if not client_redirect_url:
+            # Some SAML identity providers (e.g. Google) require a
+            # RelayState parameter on requests, so pass in a dummy redirect URL
+            # (which will never get used).
+            client_redirect_url = b"unused"
+
         reqid, info = self._saml_client.prepare_for_authenticate(
             entityid=self._saml_idp_entityid, relay_state=client_redirect_url
         )
@@ -210,7 +224,7 @@ class SamlHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id,
+                self.idp_id,
                 remote_user_id,
                 current_session.ui_auth_session_id,
                 request,
@@ -306,7 +320,7 @@ class SamlHandler(BaseHandler):
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 33cd6bc178..d096e0b091 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,15 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
+from urllib.parse import urlencode
 
 import attr
-from typing_extensions import NoReturn
+from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
 
-from synapse.api.errors import RedirectException, SynapseError
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@@ -40,6 +43,58 @@ class MappingException(Exception):
     """
 
 
+class SsoIdentityProvider(Protocol):
+    """Abstract base class to be implemented by SSO Identity Providers
+
+    An Identity Provider, or IdP, is an external HTTP service which authenticates a user
+    to say whether they should be allowed to log in, or perform a given action.
+
+    Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
+    and CAS.
+
+    The main entry point is `handle_redirect_request`, which should return a URI to
+    redirect the user's browser to the IdP's authentication page.
+
+    Each IdP should be registered with the SsoHandler via
+    `hs.get_sso_handler().register_identity_provider()`, so that requests to
+    `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
+    """
+
+    @property
+    @abc.abstractmethod
+    def idp_id(self) -> str:
+        """A unique identifier for this SSO provider
+
+        Eg, "saml", "cas", "github"
+        """
+
+    @property
+    @abc.abstractmethod
+    def idp_name(self) -> str:
+        """User-facing name for this provider"""
+
+    @abc.abstractmethod
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Handle an incoming request to /login/sso/redirect
+
+        Args:
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            URL to redirect to
+        """
+        raise NotImplementedError()
+
+
 @attr.s
 class UserAttributes:
     # the localpart of the mxid that the mapper has assigned to the user.
@@ -100,6 +155,49 @@ class SsoHandler:
         # a map from session id to session data
         self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
 
+        # map from idp_id to SsoIdentityProvider
+        self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
+
+    def register_identity_provider(self, p: SsoIdentityProvider):
+        p_id = p.idp_id
+        assert p_id not in self._identity_providers
+        self._identity_providers[p_id] = p
+
+    def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
+        """Get the configured identity providers"""
+        return self._identity_providers
+
+    async def get_identity_providers_for_user(
+        self, user_id: str
+    ) -> Mapping[str, SsoIdentityProvider]:
+        """Get the SsoIdentityProviders which a user has used
+
+        Given a user id, get the identity providers that that user has used to log in
+        with in the past (and thus could use to re-identify themselves for UI Auth).
+
+        Args:
+            user_id: MXID of user to look up
+
+        Raises:
+            a map of idp_id to SsoIdentityProvider
+        """
+        external_ids = await self._store.get_external_ids_by_user(user_id)
+
+        valid_idps = {}
+        for idp_id, _ in external_ids:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                logger.warning(
+                    "User %r has an SSO mapping for IdP %r, but this is no longer "
+                    "configured.",
+                    user_id,
+                    idp_id,
+                )
+            else:
+                valid_idps[idp_id] = idp
+
+        return valid_idps
+
     def render_error(
         self,
         request: Request,
@@ -124,6 +222,34 @@ class SsoHandler:
         )
         respond_with_html(request, code, html)
 
+    async def handle_redirect_request(
+        self, request: SynapseRequest, client_redirect_url: bytes,
+    ) -> str:
+        """Handle a request to /login/sso/redirect
+
+        Args:
+            request: incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login.
+
+        Returns:
+             the URI to redirect to
+        """
+        if not self._identity_providers:
+            raise SynapseError(
+                400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
+            )
+
+        # if we only have one auth provider, redirect to it directly
+        if len(self._identity_providers) == 1:
+            ap = next(iter(self._identity_providers.values()))
+            return await ap.handle_redirect_request(request, client_redirect_url)
+
+        # otherwise, redirect to the IDP picker
+        return "/_synapse/client/pick_idp?" + urlencode(
+            (("redirectUrl", client_redirect_url),)
+        )
+
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str
     ) -> Optional[str]:
@@ -268,7 +394,7 @@ class SsoHandler:
                     attributes,
                     auth_provider_id,
                     remote_user_id,
-                    request.get_user_agent(""),
+                    get_request_user_agent(request),
                     request.getClientIP(),
                 )
 
@@ -534,7 +660,7 @@ class SsoHandler:
             attributes,
             session.auth_provider_id,
             session.remote_user_id,
-            request.get_user_agent(""),
+            get_request_user_agent(request),
             request.getClientIP(),
         )
 
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 824f37f8f8..a68d5e790e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
 """
 
 from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS  # noqa: F401
+
+
+class UIAuthSessionDataConstants:
+    """Constants for use with AuthHandler.set_session_data"""
+
+    # used during registration and password reset to store a hashed copy of the
+    # password, so that the client does not need to submit it each time.
+    PASSWORD_HASH = "password_hash"
+
+    # used during registration to store the mxid of the registered user
+    REGISTERED_USER_ID = "registered_user_id"
+
+    # used by validate_user_via_ui_auth to store the mxid of the user we are validating
+    # for.
+    REQUEST_USER_ID = "request_user_id"
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 59b01b812c..4bc3cb53f0 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -17,6 +17,7 @@ import re
 
 from twisted.internet import task
 from twisted.web.client import FileBodyProducer
+from twisted.web.iweb import IRequest
 
 from synapse.api.errors import SynapseError
 
@@ -50,3 +51,17 @@ class QuieterFileBodyProducer(FileBodyProducer):
             FileBodyProducer.stopProducing(self)
         except task.TaskStopped:
             pass
+
+
+def get_request_user_agent(request: IRequest, default: str = "") -> str:
+    """Return the last User-Agent header, or the given default.
+    """
+    # There could be raw utf-8 bytes in the User-Agent header.
+
+    # N.B. if you don't do this, the logger explodes cryptically
+    # with maximum recursion trying to log errors about
+    # the charset problem.
+    # c.f. https://github.com/matrix-org/synapse/issues/3471
+
+    h = request.getHeader(b"User-Agent")
+    return h.decode("ascii", "replace") if h else default
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b261e078c4..b7103d6541 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -174,6 +174,16 @@ async def _handle_json_response(
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
         body = await make_deferred_yieldable(d)
+    except ValueError as e:
+        # The JSON content was invalid.
+        logger.warning(
+            "{%s} [%s] Failed to parse JSON response - %s %s",
+            request.txn_id,
+            request.destination,
+            request.method,
+            request.uri.decode("ascii"),
+        )
+        raise RequestSendFailed(e, can_retry=False) from e
     except defer.TimeoutError as e:
         logger.warning(
             "{%s} [%s] Timed out reading response - %s %s",
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5a5790831b..12ec3f851f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -20,7 +20,7 @@ from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
 
 from synapse.config.server import ListenerConfig
-from synapse.http import redact_uri
+from synapse.http import get_request_user_agent, redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.types import Requester
@@ -113,15 +113,6 @@ class SynapseRequest(Request):
             method = self.method.decode("ascii")
         return method
 
-    def get_user_agent(self, default: str) -> str:
-        """Return the last User-Agent header, or the given default.
-        """
-        user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
-        if user_agent is None:
-            return default
-
-        return user_agent.decode("ascii", "replace")
-
     def render(self, resrc):
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.
@@ -292,12 +283,7 @@ class SynapseRequest(Request):
             # and can see that we're doing something wrong.
             authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
 
-        # ...or could be raw utf-8 bytes in the User-Agent header.
-        # N.B. if you don't do this, the logger explodes cryptically
-        # with maximum recursion trying to log errors about
-        # the charset problem.
-        # c.f. https://github.com/matrix-org/synapse/issues/3471
-        user_agent = self.get_user_agent("-")
+        user_agent = get_request_user_agent(self, "-")
 
         code = str(self.code)
         if not self.finished:
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index a507a83e93..c2db8b45f3 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -252,7 +252,12 @@ class LoggingContext:
         "scope",
     ]
 
-    def __init__(self, name=None, parent_context=None, request=None) -> None:
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        parent_context: "Optional[LoggingContext]" = None,
+        request: Optional[str] = None,
+    ) -> None:
         self.previous_context = current_context()
         self.name = name
 
@@ -536,20 +541,20 @@ class LoggingContextFilter(logging.Filter):
     def __init__(self, request: str = ""):
         self._default_request = request
 
-    def filter(self, record) -> Literal[True]:
+    def filter(self, record: logging.LogRecord) -> Literal[True]:
         """Add each fields from the logging contexts to the record.
         Returns:
             True to include the record in the log output.
         """
         context = current_context()
-        record.request = self._default_request
+        record.request = self._default_request  # type: ignore
 
         # context should never be None, but if it somehow ends up being, then
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
             # Logging is interested in the request.
-            record.request = context.request
+            record.request = context.request  # type: ignore
 
         return True
 
@@ -616,9 +621,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
     return current
 
 
-def nested_logging_context(
-    suffix: str, parent_context: Optional[LoggingContext] = None
-) -> LoggingContext:
+def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.
 
     The nested logging context will have a 'request' made up of the parent context's
@@ -632,20 +635,23 @@ def nested_logging_context(
             # ... do stuff
 
     Args:
-        suffix (str): suffix to add to the parent context's 'request'.
-        parent_context (LoggingContext|None): parent context. Will use the current context
-            if None.
+        suffix: suffix to add to the parent context's 'request'.
 
     Returns:
         LoggingContext: new logging context.
     """
-    if parent_context is not None:
-        context = parent_context  # type: LoggingContextOrSentinel
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Starting nested logging context from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+        prefix = ""
     else:
-        context = current_context()
-    return LoggingContext(
-        parent_context=context, request=str(context.request) + "-" + suffix
-    )
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
+        prefix = str(parent_context.request)
+    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
 
 
 def preserve_fn(f):
@@ -822,10 +828,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = current_context()
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Calling defer_to_threadpool from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+    else:
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
 
     def g():
-        with LoggingContext(parent_context=logcontext):
+        with LoggingContext(parent_context=parent_context):
             return f(*args, **kwargs)
 
     return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c4c8bb271d..0745899b48 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -396,31 +396,30 @@ class Notifier:
 
         Will wake up all listeners for the given users and rooms.
         """
-        with PreserveLoggingContext():
-            with Measure(self.clock, "on_new_event"):
-                user_streams = set()
+        with Measure(self.clock, "on_new_event"):
+            user_streams = set()
 
-                for user in users:
-                    user_stream = self.user_to_user_stream.get(str(user))
-                    if user_stream is not None:
-                        user_streams.add(user_stream)
+            for user in users:
+                user_stream = self.user_to_user_stream.get(str(user))
+                if user_stream is not None:
+                    user_streams.add(user_stream)
 
-                for room in rooms:
-                    user_streams |= self.room_to_user_streams.get(room, set())
+            for room in rooms:
+                user_streams |= self.room_to_user_streams.get(room, set())
 
-                time_now_ms = self.clock.time_msec()
-                for user_stream in user_streams:
-                    try:
-                        user_stream.notify(stream_key, new_token, time_now_ms)
-                    except Exception:
-                        logger.exception("Failed to notify listener")
+            time_now_ms = self.clock.time_msec()
+            for user_stream in user_streams:
+                try:
+                    user_stream.notify(stream_key, new_token, time_now_ms)
+                except Exception:
+                    logger.exception("Failed to notify listener")
 
-                self.notify_replication()
+            self.notify_replication()
 
-                # Notify appservices
-                self._notify_app_services_ephemeral(
-                    stream_key, new_token, users,
-                )
+            # Notify appservices
+            self._notify_app_services_ephemeral(
+                stream_key, new_token, users,
+            )
 
     def on_new_replication_data(self) -> None:
         """Used to inform replication listeners that something has happened
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 10f27e4378..9018f9e20b 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -203,14 +203,18 @@ class BulkPushRuleEvaluator:
 
         condition_cache = {}  # type: Dict[str, bool]
 
+        # If the event is not a state event check if any users ignore the sender.
+        if not event.is_state():
+            ignorers = await self.store.ignored_by(event.sender)
+        else:
+            ignorers = set()
+
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
                 continue
 
-            if not event.is_state():
-                is_ignored = await self.store.is_ignored_by(event.sender, uid)
-                if is_ignored:
-                    continue
+            if uid in ignorers:
+                continue
 
             display_name = None
             profile_info = room_members.get(uid)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 5b045bed02..1260f6d141 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -14,46 +14,8 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-        self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_inbox", "stream_id"
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(instance_name, token)
-            for row in rows:
-                if row.entity.startswith("@"):
-                    self._device_inbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-                else:
-                    self._device_federation_outbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 95e5502bf2..1f89249475 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -56,6 +56,7 @@ from synapse.replication.tcp.streams import (
     EventsStream,
     FederationStream,
     Stream,
+    ToDeviceStream,
     TypingStream,
 )
 
@@ -115,6 +116,14 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, ToDeviceStream):
+                # Only add ToDeviceStream as a source on instances in charge of
+                # sending to device messages.
+                if hs.get_instance_name() in hs.config.worker.writers.to_device:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             if isinstance(stream, TypingStream):
                 # Only add TypingStream as a source on the instance in charge of
                 # typing.
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
new file mode 100644
index 0000000000..f53c9cd679
--- /dev/null
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -0,0 +1,28 @@
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
+        <title>{{server_name | e}} Login</title>
+    </head>
+    <body>
+        <div id="container">
+            <h1 id="title">{{server_name | e}} Login</h1>
+            <div class="login_flow">
+                <p>Choose one of the following identity providers:</p>
+            <form>
+                <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
+                <ul class="radiobuttons">
+{% for p in providers %}
+                    <li>
+                        <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
+                        <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
+                    </li>
+{% endfor %}
+                </ul>
+                <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+            </form>
+            </div>
+        </div>
+    </body>
+</html>
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 6658c2da56..f39e3d6d5c 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -244,7 +244,7 @@ class UserRestServletV2(RestServlet):
 
                 if deactivate and not user["deactivated"]:
                     await self.deactivate_account_handler.deactivate_account(
-                        target_user.to_string(), False
+                        target_user.to_string(), False, requester, by_admin=True
                     )
                 elif not deactivate and user["deactivated"]:
                     if "password" not in body:
@@ -486,12 +486,22 @@ class WhoisRestServlet(RestServlet):
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
+        self.store = hs.get_datastore()
+
+    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        if not self.is_mine(UserID.from_string(target_user_id)):
+            raise SynapseError(400, "Can only deactivate local users")
+
+        if not await self.store.get_user_by_id(target_user_id):
+            raise NotFoundError("User not found")
 
-    async def on_POST(self, request, target_user_id):
-        await assert_requester_is_admin(self.auth, request)
         body = parse_json_object_from_request(request, allow_empty_body=True)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -501,10 +511,8 @@ class DeactivateAccountRestServlet(RestServlet):
                 Codes.BAD_JSON,
             )
 
-        UserID.from_string(target_user_id)
-
         result = await self._deactivate_account_handler.deactivate_account(
-            target_user_id, erase
+            target_user_id, erase, requester, by_admin=True
         )
         if result:
             id_server_unbind_result = "success"
@@ -714,13 +722,6 @@ class UserMembershipRestServlet(RestServlet):
     async def on_GET(self, request, user_id):
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only lookup local users")
-
-        user = await self.store.get_user_by_id(user_id)
-        if user is None:
-            raise NotFoundError("Unknown user")
-
         room_ids = await self.store.get_rooms_for_user(user_id)
         ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
         return 200, ret
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5f4c6703db..be938df962 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-class BaseSSORedirectServlet(RestServlet):
-    """Common base class for /login/sso/redirect impls"""
-
+class SsoRedirectServlet(RestServlet):
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
+    def __init__(self, hs: "HomeServer"):
+        # make sure that the relevant handlers are instantiated, so that they
+        # register themselves with the main SSOHandler.
+        if hs.config.cas_enabled:
+            hs.get_cas_handler()
+        if hs.config.saml2_enabled:
+            hs.get_saml_handler()
+        if hs.config.oidc_enabled:
+            hs.get_oidc_handler()
+        self._sso_handler = hs.get_sso_handler()
+
     async def on_GET(self, request: SynapseRequest):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = await self.get_sso_url(request, client_redirect_url)
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding=None
+        )
+        sso_url = await self._sso_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
+        logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        """Get the URL to redirect to, to perform SSO auth
-
-        Args:
-            request: The client request to redirect.
-            client_redirect_url: the URL that we should redirect the
-                client to when everything is done
-
-        Returns:
-            URL to redirect to
-        """
-        # to be implemented by subclasses
-        raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
-    def __init__(self, hs):
-        self._cas_handler = hs.get_cas_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._cas_handler.get_redirect_url(
-            {"redirectUrl": client_redirect_url}
-        ).encode("ascii")
-
 
 class CasTicketServlet(RestServlet):
     PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet):
         )
 
 
-class SAMLRedirectServlet(BaseSSORedirectServlet):
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._saml_handler = hs.get_saml_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
-    """Implementation for /login/sso/redirect for the OIDC login flow."""
-
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._oidc_handler = hs.get_oidc_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return await self._oidc_handler.handle_redirect_request(
-            request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
-        CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-    elif hs.config.saml2_enabled:
-        SAMLRedirectServlet(hs).register(http_server)
-    elif hs.config.oidc_enabled:
-        OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index d837bde1d6..65e68d641b 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,9 +20,6 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING
 from urllib.parse import urlparse
 
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -31,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
 )
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
 
 from ._base import client_patterns, interactive_auth_handler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -189,11 +191,7 @@ class PasswordRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             try:
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
-                    requester,
-                    request,
-                    body,
-                    self.hs.get_ip_from_request(request),
-                    "modify your account password",
+                    requester, request, body, "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
                 # The user needs to provide more steps to complete auth, but
@@ -204,7 +202,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
             user_id = requester.user.to_string()
@@ -215,7 +215,6 @@ class PasswordRestServlet(RestServlet):
                     [[LoginType.EMAIL_IDENTITY]],
                     request,
                     body,
-                    self.hs.get_ip_from_request(request),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -227,7 +226,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
 
@@ -260,7 +261,7 @@ class PasswordRestServlet(RestServlet):
             password_hash = await self.auth_handler.hash(new_password)
         elif session_id is not None:
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
         else:
             # UI validation was skipped, but the request did not include a new
@@ -304,19 +305,18 @@ class DeactivateAccountRestServlet(RestServlet):
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase
+                requester.user.to_string(), erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "deactivate your account",
+            requester, request, body, "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(), erase, id_server=body.get("id_server")
+            requester.user.to_string(),
+            erase,
+            requester,
+            id_server=body.get("id_server"),
         )
         if result:
             id_server_unbind_result = "success"
@@ -695,11 +695,7 @@ class ThreepidAddRestServlet(RestServlet):
         assert_valid_client_secret(client_secret)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a third-party identifier to your account",
+            requester, request, body, "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
 
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
-
-        # SSO configuration.
-        self._cas_enabled = hs.config.cas_enabled
-        if self._cas_enabled:
-            self._cas_handler = hs.get_cas_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
-        self._oidc_enabled = hs.config.oidc_enabled
-        if self._oidc_enabled:
-            self._oidc_handler = hs.get_oidc_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
         self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-            if self._cas_enabled:
-                # Generate a request to CAS that redirects back to an endpoint
-                # to verify the successful authentication.
-                sso_redirect_url = self._cas_handler.get_redirect_url(
-                    {"session": session},
-                )
-
-            elif self._saml_enabled:
-                # Some SAML identity providers (e.g. Google) require a
-                # RelayState parameter on requests. It is not necessary here, so
-                # pass in a dummy redirect URL (which will never get used).
-                client_redirect_url = b"unused"
-                sso_redirect_url = self._saml_handler.handle_redirect_request(
-                    client_redirect_url, session
-                )
-
-            elif self._oidc_enabled:
-                client_redirect_url = b""
-                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
-                    request, client_redirect_url, session
-                )
-
-            else:
-                raise SynapseError(400, "Homeserver not configured for SSO.")
-
-            html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+            html = await self.auth_handler.start_sso_ui_auth(request, session)
 
         else:
             raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"response": response, "session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+                LoginType.RECAPTCHA, authdict, request.getClientIP()
             )
 
             if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+                LoginType.TERMS, authdict, request.getClientIP()
             )
 
             if success:
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index af117cb27c..314e01dfe4 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
         assert_params_in_dict(body, ["devices"])
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove device(s) from your account",
+            requester, request, body, "remove device(s) from your account",
         )
 
         await self.device_handler.delete_devices(
@@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
                 raise
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove a device from your account",
+            requester, request, body, "remove a device from your account",
         )
 
         await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index b91996c738..a6134ead8a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
         body = parse_json_object_from_request(request)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a device signing key to your account",
+            requester, request, body, "add a device signing key to your account",
         )
 
         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 6b5a1b7109..b093183e79 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.config.registration import RegistrationConfig
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -353,7 +354,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
                 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
             )
 
-        ip = self.hs.get_ip_from_request(request)
+        ip = request.getClientIP()
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             await wait_deferred
 
@@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet):
             # user here. We carry on and go through the auth checks though,
             # for paranoia.
             registered_user_id = await self.auth_handler.get_session_data(
-                session_id, "registered_user_id", None
+                session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
             )
             # Extract the previously-hashed password from the session.
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
 
         # Ensure that the username is valid.
@@ -513,11 +514,7 @@ class RegisterRestServlet(RestServlet):
         # not this will raise a user-interactive auth error.
         try:
             auth_result, params, session_id = await self.auth_handler.check_ui_auth(
-                self._registration_flows,
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "register a new account",
+                self._registration_flows, request, body, "register a new account",
             )
         except InteractiveAuthIncompleteError as e:
             # The user needs to provide more steps to complete auth.
@@ -532,7 +529,9 @@ class RegisterRestServlet(RestServlet):
             if not password_hash and password:
                 password_hash = await self.auth_handler.hash(password)
                 await self.auth_handler.set_session_data(
-                    e.session_id, "password_hash", password_hash
+                    e.session_id,
+                    UIAuthSessionDataConstants.PASSWORD_HASH,
+                    password_hash,
                 )
             raise
 
@@ -633,7 +632,9 @@ class RegisterRestServlet(RestServlet):
             # Remember that the user account has been registered (and the user
             # ID it was registered with, since it might not have been specified).
             await self.auth_handler.set_session_data(
-                session_id, "registered_user_id", registered_user_id
+                session_id,
+                UIAuthSessionDataConstants.REGISTERED_USER_ID,
+                registered_user_id,
             )
 
             registered = True
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
new file mode 100644
index 0000000000..e5b720bbca
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    finish_request,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PickIdpResource(DirectServeHtmlResource):
+    """IdP picker resource.
+
+    This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page
+    which prompts the user to choose an Identity Provider from the list.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+        self._sso_login_idp_picker_template = (
+            hs.config.sso.sso_login_idp_picker_template
+        )
+        self._server_name = hs.hostname
+
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
+        client_redirect_url = parse_string(request, "redirectUrl", required=True)
+        idp = parse_string(request, "idp", required=False)
+
+        # if we need to pick an IdP, do so
+        if not idp:
+            return await self._serve_id_picker(request, client_redirect_url)
+
+        # otherwise, redirect to the IdP's redirect URI
+        providers = self._sso_handler.get_identity_providers()
+        auth_provider = providers.get(idp)
+        if not auth_provider:
+            logger.info("Unknown idp %r", idp)
+            self._sso_handler.render_error(
+                request, "unknown_idp", "Unknown identity provider ID"
+            )
+            return
+
+        sso_url = await auth_provider.handle_redirect_request(
+            request, client_redirect_url.encode("utf8")
+        )
+        logger.info("Redirecting to %s", sso_url)
+        request.redirect(sso_url)
+        finish_request(request)
+
+    async def _serve_id_picker(
+        self, request: SynapseRequest, client_redirect_url: str
+    ) -> None:
+        # otherwise, serve up the IdP picker
+        providers = self._sso_handler.get_identity_providers()
+        html = self._sso_login_idp_picker_template.render(
+            redirect_url=client_redirect_url,
+            server_name=self._server_name,
+            providers=providers.values(),
+        )
+        respond_with_html(request, 200, html)
diff --git a/synapse/server.py b/synapse/server.py
index a198b0eb46..d4c235cda5 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -283,10 +283,6 @@ class HomeServer(metaclass=abc.ABCMeta):
         """
         return self._reactor
 
-    def get_ip_from_request(self, request) -> str:
-        # X-Forwarded-For is handled by our custom request type.
-        return request.getClientIP()
-
     def is_mine(self, domain_specific_string: DomainSpecificString) -> bool:
         return domain_specific_string.domain == self.hostname
 
@@ -505,7 +501,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return InitialSyncHandler(self)
 
     @cache_in_self
-    def get_profile_handler(self):
+    def get_profile_handler(self) -> ProfileHandler:
         return ProfileHandler(self)
 
     @cache_in_self
diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css
index 83e4f6abc8..dd76714a92 100644
--- a/synapse/static/client/login/style.css
+++ b/synapse/static/client/login/style.css
@@ -31,6 +31,11 @@ form {
     margin: 10px 0 0 0;
 }
 
+ul.radiobuttons {
+    text-align: left;
+    list-style: none;
+}
+
 /*
  * Add some padding to the viewport.
  */
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d1b5760c2c..6cfadc2b4e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -42,7 +42,6 @@ from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import (
     LoggingContext,
-    LoggingContextOrSentinel,
     current_context,
     make_deferred_yieldable,
 )
@@ -180,6 +179,9 @@ class LoggingDatabaseConnection:
 _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
 
 
+R = TypeVar("R")
+
+
 class LoggingTransaction:
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -267,6 +269,20 @@ class LoggingTransaction:
             for val in args:
                 self.execute(sql, val)
 
+    def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+        """Corresponds to psycopg2.extras.execute_values. Only available when
+        using postgres.
+
+        Always sets fetch=True when caling `execute_values`, so will return the
+        results.
+        """
+        assert isinstance(self.database_engine, PostgresEngine)
+        from psycopg2.extras import execute_values  # type: ignore
+
+        return self._do_execute(
+            lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+        )
+
     def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
 
@@ -277,7 +293,7 @@ class LoggingTransaction:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql: str, *args: Any) -> None:
+    def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -348,9 +364,6 @@ class PerformanceCounters:
         return top_n_counters
 
 
-R = TypeVar("R")
-
-
 class DatabasePool:
     """Wraps a single physical database and connection pool.
 
@@ -671,12 +684,15 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
-        if not parent_context:
+        curr_context = current_context()
+        if not curr_context:
             logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
             )
             parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
 
         start_time = monotonic_time()
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 701748f93b..c4de07a0a8 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -127,9 +127,6 @@ class DataStore(
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
-        self._device_inbox_id_gen = StreamIdGenerator(
-            db_conn, "device_inbox", "stream_id"
-        )
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
@@ -189,36 +186,6 @@ class DataStore(
             prefilled_cache=presence_cache_prefill,
         )
 
-        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
-        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_inbox",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            min_device_inbox_id,
-            prefilled_cache=device_inbox_prefill,
-        )
-        # The federation outbox and the local device inbox uses the same
-        # stream_id generator.
-        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_federation_outbox",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            min_device_outbox_id,
-            prefilled_cache=device_outbox_prefill,
-        )
-
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
             "DeviceListStreamChangeCache", device_list_max
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 49ee23470d..bad8260892 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,7 +16,7 @@
 
 import abc
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from synapse.api.constants import AccountDataTypes
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -24,7 +24,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import _CacheContext, cached
+from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
@@ -287,35 +287,34 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
-    @cached(num_args=2, cache_context=True, max_entries=5000)
-    async def is_ignored_by(
-        self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
-    ) -> bool:
-        ignored_account_data = await self.get_global_account_data_by_type_for_user(
-            AccountDataTypes.IGNORED_USER_LIST,
-            ignorer_user_id,
-            on_invalidate=cache_context.invalidate,
-        )
-        if not ignored_account_data:
-            return False
+    @cached(max_entries=5000, iterable=True)
+    async def ignored_by(self, user_id: str) -> Set[str]:
+        """
+        Get users which ignore the given user.
 
-        try:
-            return ignored_user_id in ignored_account_data.get("ignored_users", {})
-        except TypeError:
-            # The type of the ignored_users field is invalid.
-            return False
+        Params:
+            user_id: The user ID which might be ignored.
+
+        Return:
+            The user IDs which ignore the given user.
+        """
+        return set(
+            await self.db_pool.simple_select_onecol(
+                table="ignored_users",
+                keyvalues={"ignored_user_id": user_id},
+                retcol="ignorer_user_id",
+                desc="ignored_by",
+            )
+        )
 
 
 class AccountDataStore(AccountDataWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         self._account_data_id_gen = StreamIdGenerator(
             db_conn,
-            "account_data_max_stream_id",
+            "room_account_data",
             "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
+            extra_tables=[("room_tags_revisions", "stream_id")],
         )
 
         super().__init__(database, db_conn, hs)
@@ -360,14 +359,6 @@ class AccountDataStore(AccountDataWorkerStore):
                 lock=False,
             )
 
-            # it's theoretically possible for the above to succeed and the
-            # below to fail - in which case we might reuse a stream id on
-            # restart, and the above update might not get propagated. That
-            # doesn't sound any worse than the whole update getting lost,
-            # which is what would happen if we combined the two into one
-            # transaction.
-            await self._update_max_stream_id(next_id)
-
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
@@ -390,32 +381,16 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
-        content_json = json_encoder.encode(content)
-
         async with self._account_data_id_gen.get_next() as next_id:
-            # no need to lock here as account_data has a unique constraint on
-            # (user_id, account_data_type) so simple_upsert will retry if
-            # there is a conflict.
-            await self.db_pool.simple_upsert(
-                desc="add_user_account_data",
-                table="account_data",
-                keyvalues={"user_id": user_id, "account_data_type": account_data_type},
-                values={"stream_id": next_id, "content": content_json},
-                lock=False,
+            await self.db_pool.runInteraction(
+                "add_user_account_data",
+                self._add_account_data_for_user,
+                next_id,
+                user_id,
+                account_data_type,
+                content,
             )
 
-            # it's theoretically possible for the above to succeed and the
-            # below to fail - in which case we might reuse a stream id on
-            # restart, and the above update might not get propagated. That
-            # doesn't sound any worse than the whole update getting lost,
-            # which is what would happen if we combined the two into one
-            # transaction.
-            #
-            # Note: This is only here for backwards compat to allow admins to
-            # roll back to a previous Synapse version. Next time we update the
-            # database version we can remove this table.
-            await self._update_max_stream_id(next_id)
-
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
@@ -424,23 +399,67 @@ class AccountDataStore(AccountDataWorkerStore):
 
         return self._account_data_id_gen.get_current_token()
 
-    async def _update_max_stream_id(self, next_id: int) -> None:
-        """Update the max stream_id
+    def _add_account_data_for_user(
+        self,
+        txn,
+        next_id: int,
+        user_id: str,
+        account_data_type: str,
+        content: JsonDict,
+    ) -> None:
+        content_json = json_encoder.encode(content)
 
-        Args:
-            next_id: The the revision to advance to.
-        """
+        # no need to lock here as account_data has a unique constraint on
+        # (user_id, account_data_type) so simple_upsert will retry if
+        # there is a conflict.
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="account_data",
+            keyvalues={"user_id": user_id, "account_data_type": account_data_type},
+            values={"stream_id": next_id, "content": content_json},
+            lock=False,
+        )
 
-        # Note: This is only here for backwards compat to allow admins to
-        # roll back to a previous Synapse version. Next time we update the
-        # database version we can remove this table.
+        # Ignored users get denormalized into a separate table as an optimisation.
+        if account_data_type != AccountDataTypes.IGNORED_USER_LIST:
+            return
 
-        def _update(txn):
-            update_max_id_sql = (
-                "UPDATE account_data_max_stream_id"
-                " SET stream_id = ?"
-                " WHERE stream_id < ?"
+        # Insert / delete to sync the list of ignored users.
+        previously_ignored_users = set(
+            self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="ignored_users",
+                keyvalues={"ignorer_user_id": user_id},
+                retcol="ignored_user_id",
             )
-            txn.execute(update_max_id_sql, (next_id, next_id))
+        )
+
+        # If the data is invalid, no one is ignored.
+        ignored_users_content = content.get("ignored_users", {})
+        if isinstance(ignored_users_content, dict):
+            currently_ignored_users = set(ignored_users_content)
+        else:
+            currently_ignored_users = set()
+
+        # Delete entries which are no longer ignored.
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="ignored_users",
+            column="ignored_user_id",
+            iterable=previously_ignored_users - currently_ignored_users,
+            keyvalues={"ignorer_user_id": user_id},
+        )
+
+        # Add entries which are newly ignored.
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="ignored_users",
+            values=[
+                {"ignorer_user_id": user_id, "ignored_user_id": u}
+                for u in currently_ignored_users - previously_ignored_users
+            ],
+        )
 
-        await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+        # Invalidate the cache for any ignored users which were added or removed.
+        for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
+            self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index e96a8b3f43..c53c836337 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -470,43 +470,35 @@ class ClientIpStore(ClientIpWorkerStore):
         for entry in to_update.items():
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
 
-            try:
-                self.db_pool.simple_upsert_txn(
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="user_ips",
+                keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
+                values={
+                    "user_agent": user_agent,
+                    "device_id": device_id,
+                    "last_seen": last_seen,
+                },
+                lock=False,
+            )
+
+            # Technically an access token might not be associated with
+            # a device so we need to check.
+            if device_id:
+                # this is always an update rather than an upsert: the row should
+                # already exist, and if it doesn't, that may be because it has been
+                # deleted, and we don't want to re-create it.
+                self.db_pool.simple_update_txn(
                     txn,
-                    table="user_ips",
-                    keyvalues={
-                        "user_id": user_id,
-                        "access_token": access_token,
-                        "ip": ip,
-                    },
-                    values={
+                    table="devices",
+                    keyvalues={"user_id": user_id, "device_id": device_id},
+                    updatevalues={
                         "user_agent": user_agent,
-                        "device_id": device_id,
                         "last_seen": last_seen,
+                        "ip": ip,
                     },
-                    lock=False,
                 )
 
-                # Technically an access token might not be associated with
-                # a device so we need to check.
-                if device_id:
-                    # this is always an update rather than an upsert: the row should
-                    # already exist, and if it doesn't, that may be because it has been
-                    # deleted, and we don't want to re-create it.
-                    self.db_pool.simple_update_txn(
-                        txn,
-                        table="devices",
-                        keyvalues={"user_id": user_id, "device_id": device_id},
-                        updatevalues={
-                            "user_agent": user_agent,
-                            "last_seen": last_seen,
-                            "ip": ip,
-                        },
-                    )
-            except Exception as e:
-                # Failed to upsert, log and continue
-                logger.error("Failed to insert client IP %r: %r", entry, e)
-
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
     ) -> Dict[Tuple[str, str], dict]:
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index d42faa3f1f..58d3f71e45 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -17,15 +17,100 @@ import logging
 from typing import List, Tuple
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.replication.tcp.streams import ToDeviceStream
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
 
 
 class DeviceInboxWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._instance_name = hs.get_instance_name()
+
+        # Map of (user_id, device_id) to the last stream_id that has been
+        # deleted up to. This is so that we can no op deletions.
+        self._last_device_delete_cache = ExpiringCache(
+            cache_name="last_device_delete_cache",
+            clock=self._clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+        )
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_device = (
+                self._instance_name in hs.config.worker.writers.to_device
+            )
+
+            self._device_inbox_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="to_device",
+                instance_name=self._instance_name,
+                table="device_inbox",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="device_inbox_sequence",
+                writers=hs.config.worker.writers.to_device,
+            )
+        else:
+            self._can_write_to_device = True
+            self._device_inbox_id_gen = StreamIdGenerator(
+                db_conn, "device_inbox", "stream_id"
+            )
+
+        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_inbox",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_inbox_stream_cache = StreamChangeCache(
+            "DeviceInboxStreamChangeCache",
+            min_device_inbox_id,
+            prefilled_cache=device_inbox_prefill,
+        )
+
+        # The federation outbox and the local device inbox uses the same
+        # stream_id generator.
+        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_federation_outbox",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache",
+            min_device_outbox_id,
+            prefilled_cache=device_outbox_prefill,
+        )
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+            for row in rows:
+                if row.entity.startswith("@"):
+                    self._device_inbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+                else:
+                    self._device_federation_outbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
@@ -278,52 +363,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )
 
-
-class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
-    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self.db_pool.updates.register_background_index_update(
-            "device_inbox_stream_index",
-            index_name="device_inbox_stream_id_user_id",
-            table="device_inbox",
-            columns=["stream_id", "user_id"],
-        )
-
-        self.db_pool.updates.register_background_update_handler(
-            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
-        )
-
-    async def _background_drop_index_device_inbox(self, progress, batch_size):
-        def reindex_txn(conn):
-            txn = conn.cursor()
-            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
-            txn.close()
-
-        await self.db_pool.runWithConnection(reindex_txn)
-
-        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
-        return 1
-
-
-class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
-    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        # Map of (user_id, device_id) to the last stream_id that has been
-        # deleted up to. This is so that we can no op deletions.
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
     @trace
     async def add_messages_to_device_inbox(
         self,
@@ -342,6 +381,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             The new stream_id.
         """
 
+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Add the local messages directly to the local inbox.
             self._add_messages_to_local_device_inbox_txn(
@@ -351,16 +392,20 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             # Add the remote messages to the federation outbox.
             # We'll send them to a remote server when we next send a
             # federation transaction to that destination.
-            sql = (
-                "INSERT INTO device_federation_outbox"
-                " (destination, stream_id, queued_ts, messages_json)"
-                " VALUES (?,?,?,?)"
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="device_federation_outbox",
+                values=[
+                    {
+                        "destination": destination,
+                        "stream_id": stream_id,
+                        "queued_ts": now_ms,
+                        "messages_json": json_encoder.encode(edu),
+                        "instance_name": self._instance_name,
+                    }
+                    for destination, edu in remote_messages_by_destination.items()
+                ],
             )
-            rows = []
-            for destination, edu in remote_messages_by_destination.items():
-                edu_json = json_encoder.encode(edu)
-                rows.append((destination, stream_id, now_ms, edu_json))
-            txn.executemany(sql, rows)
 
         async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
@@ -379,6 +424,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     async def add_messages_from_remote_to_device_inbox(
         self, origin: str, message_id: str, local_messages_by_user_then_device: dict
     ) -> int:
+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
@@ -427,38 +474,45 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     def _add_messages_to_local_device_inbox_txn(
         self, txn, stream_id, messages_by_user_then_device
     ):
+        assert self._can_write_to_device
+
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
             messages_json_for_user = {}
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
-                sql = "SELECT device_id FROM devices WHERE user_id = ?"
-                txn.execute(sql, (user_id,))
+                devices = self.db_pool.simple_select_onecol_txn(
+                    txn,
+                    table="devices",
+                    keyvalues={"user_id": user_id},
+                    retcol="device_id",
+                )
+
                 message_json = json_encoder.encode(messages_by_device["*"])
-                for row in txn:
+                for device_id in devices:
                     # Add the message for all devices for this user on this
                     # server.
-                    device = row[0]
-                    messages_json_for_user[device] = message_json
+                    messages_json_for_user[device_id] = message_json
             else:
                 if not devices:
                     continue
 
-                clause, args = make_in_list_sql_clause(
-                    txn.database_engine, "device_id", devices
+                rows = self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="devices",
+                    keyvalues={"user_id": user_id},
+                    column="device_id",
+                    iterable=devices,
+                    retcols=("device_id",),
                 )
-                sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
 
-                # TODO: Maybe this needs to be done in batches if there are
-                # too many local devices for a given user.
-                txn.execute(sql, [user_id] + list(args))
-                for row in txn:
+                for row in rows:
                     # Only insert into the local inbox if the device exists on
                     # this server
-                    device = row[0]
-                    message_json = json_encoder.encode(messages_by_device[device])
-                    messages_json_for_user[device] = message_json
+                    device_id = row["device_id"]
+                    message_json = json_encoder.encode(messages_by_device[device_id])
+                    messages_json_for_user[device_id] = message_json
 
             if messages_json_for_user:
                 local_by_user_then_device[user_id] = messages_json_for_user
@@ -466,14 +520,52 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         if not local_by_user_then_device:
             return
 
-        sql = (
-            "INSERT INTO device_inbox"
-            " (user_id, device_id, stream_id, message_json)"
-            " VALUES (?,?,?,?)"
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_inbox",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "stream_id": stream_id,
+                    "message_json": message_json,
+                    "instance_name": self._instance_name,
+                }
+                for user_id, messages_by_device in local_by_user_then_device.items()
+                for device_id, message_json in messages_by_device.items()
+            ],
         )
-        rows = []
-        for user_id, messages_by_device in local_by_user_then_device.items():
-            for device_id, message_json in messages_by_device.items():
-                rows.append((user_id, device_id, stream_id, message_json))
 
-        txn.executemany(sql, rows)
+
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_index_update(
+            "device_inbox_stream_index",
+            index_name="device_inbox_stream_id_user_id",
+            table="device_inbox",
+            columns=["stream_id", "user_id"],
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
+        )
+
+    async def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            txn = conn.cursor()
+            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+            txn.close()
+
+        await self.db_pool.runWithConnection(reindex_txn)
+
+        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+    pass
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4d1b92d1aa..1b6ccd51c8 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -707,50 +707,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         """Get the current stream id from the _device_list_id_gen"""
         ...
 
-
-class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
-    async def set_e2e_device_keys(
-        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
-    ) -> bool:
-        """Stores device keys for a device. Returns whether there was a change
-        or the keys were already in the database.
-        """
-
-        def _set_e2e_device_keys_txn(txn):
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
-
-            old_key_json = self.db_pool.simple_select_one_onecol_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                retcol="key_json",
-                allow_none=True,
-            )
-
-            # In py3 we need old_key_json to match new_key_json type. The DB
-            # returns unicode while encode_canonical_json returns bytes.
-            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
-            if old_key_json == new_key_json:
-                log_kv({"Message": "Device key already stored."})
-                return False
-
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                values={"ts_added_ms": time_now, "key_json": new_key_json},
-            )
-            log_kv({"message": "Device keys stored."})
-            return True
-
-        return await self.db_pool.runInteraction(
-            "set_e2e_device_keys", _set_e2e_device_keys_txn
-        )
-
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
     ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
@@ -840,6 +796,50 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+    async def set_e2e_device_keys(
+        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+    ) -> bool:
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
+        """
+
+        def _set_e2e_device_keys_txn(txn):
+            set_tag("user_id", user_id)
+            set_tag("device_id", device_id)
+            set_tag("time_now", time_now)
+            set_tag("device_keys", device_keys)
+
+            old_key_json = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                retcol="key_json",
+                allow_none=True,
+            )
+
+            # In py3 we need old_key_json to match new_key_json type. The DB
+            # returns unicode while encode_canonical_json returns bytes.
+            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+            if old_key_json == new_key_json:
+                log_kv({"Message": "Device key already stored."})
+                return False
+
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                values={"ts_added_ms": time_now, "key_json": new_key_json},
+            )
+            log_kv({"message": "Device keys stored."})
+            return True
+
+        return await self.db_pool.runInteraction(
+            "set_e2e_device_keys", _set_e2e_device_keys_txn
+        )
+
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn):
             log_kv(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ebffd89251..8326640d20 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Cursor
 from synapse.types import Collection
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
 logger = logging.getLogger(__name__)
 
 
+class _NoChainCoverIndex(Exception):
+    def __init__(self, room_id: str):
+        super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
+
+
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             The set of the difference in auth chains.
         """
 
+        # Check if we have indexed the room so we can use the chain cover
+        # algorithm.
+        room = await self.get_room(room_id)
+        if room["has_auth_chain_index"]:
+            try:
+                return await self.db_pool.runInteraction(
+                    "get_auth_chain_difference_chains",
+                    self._get_auth_chain_difference_using_cover_index_txn,
+                    room_id,
+                    state_sets,
+                )
+            except _NoChainCoverIndex:
+                # For whatever reason we don't actually have a chain cover index
+                # for the events in question, so we fall back to the old method.
+                pass
+
         return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
             self._get_auth_chain_difference_txn,
             state_sets,
         )
 
+    def _get_auth_chain_difference_using_cover_index_txn(
+        self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+    ) -> Set[str]:
+        """Calculates the auth chain difference using the chain index.
+
+        See docs/auth_chain_difference_algorithm.md for details
+        """
+
+        # First we look up the chain ID/sequence numbers for all the events, and
+        # work out the chain/sequence numbers reachable from each state set.
+
+        initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+        # Map from event_id -> (chain ID, seq no)
+        chain_info = {}  # type: Dict[str, Tuple[int, int]]
+
+        # Map from chain ID -> seq no -> event Id
+        chain_to_event = {}  # type: Dict[int, Dict[int, str]]
+
+        # All the chains that we've found that are reachable from the state
+        # sets.
+        seen_chains = set()  # type: Set[int]
+
+        sql = """
+            SELECT event_id, chain_id, sequence_number
+            FROM event_auth_chains
+            WHERE %s
+        """
+        for batch in batch_iter(initial_events, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for event_id, chain_id, sequence_number in txn:
+                chain_info[event_id] = (chain_id, sequence_number)
+                seen_chains.add(chain_id)
+                chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+        # Check that we actually have a chain ID for all the events.
+        events_missing_chain_info = initial_events.difference(chain_info)
+        if events_missing_chain_info:
+            # This can happen due to e.g. downgrade/upgrade of the server. We
+            # raise an exception and fall back to the previous algorithm.
+            logger.info(
+                "Unexpectedly found that events don't have chain IDs in room %s: %s",
+                room_id,
+                events_missing_chain_info,
+            )
+            raise _NoChainCoverIndex(room_id)
+
+        # Corresponds to `state_sets`, except as a map from chain ID to max
+        # sequence number reachable from the state set.
+        set_to_chain = []  # type: List[Dict[int, int]]
+        for state_set in state_sets:
+            chains = {}  # type: Dict[int, int]
+            set_to_chain.append(chains)
+
+            for event_id in state_set:
+                chain_id, seq_no = chain_info[event_id]
+
+                chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
+
+        # Now we look up all links for the chains we have, adding chains to
+        # set_to_chain that are reachable from each set.
+        sql = """
+            SELECT
+                origin_chain_id, origin_sequence_number,
+                target_chain_id, target_sequence_number
+            FROM event_auth_chain_links
+            WHERE %s
+        """
+
+        # (We need to take a copy of `seen_chains` as we want to mutate it in
+        # the loop)
+        for batch in batch_iter(set(seen_chains), 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "origin_chain_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in txn:
+                for chains in set_to_chain:
+                    # chains are only reachable if the origin sequence number of
+                    # the link is less than the max sequence number in the
+                    # origin chain.
+                    if origin_sequence_number <= chains.get(origin_chain_id, 0):
+                        chains[target_chain_id] = max(
+                            target_sequence_number, chains.get(target_chain_id, 0),
+                        )
+
+                seen_chains.add(target_chain_id)
+
+        # Now for each chain we figure out the maximum sequence number reachable
+        # from *any* state set and the minimum sequence number reachable from
+        # *all* state sets. Events in that range are in the auth chain
+        # difference.
+        result = set()
+
+        # Mapping from chain ID to the range of sequence numbers that should be
+        # pulled from the database.
+        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]
+
+        for chain_id in seen_chains:
+            min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
+            max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
+
+            if min_seq_no < max_seq_no:
+                # We have a non empty gap, try and fill it from the events that
+                # we have, otherwise add them to the list of gaps to pull out
+                # from the DB.
+                for seq_no in range(min_seq_no + 1, max_seq_no + 1):
+                    event_id = chain_to_event.get(chain_id, {}).get(seq_no)
+                    if event_id:
+                        result.add(event_id)
+                    else:
+                        chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
+                        break
+
+        if not chain_to_gap:
+            # If there are no gaps to fetch, we're done!
+            return result
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # We can use `execute_values` to efficiently fetch the gaps when
+            # using postgres.
+            sql = """
+                SELECT event_id
+                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
+                WHERE
+                    c.chain_id = l.chain_id
+                    AND min_seq < sequence_number AND sequence_number <= max_seq
+            """
+
+            args = [
+                (chain_id, min_no, max_no)
+                for chain_id, (min_no, max_no) in chain_to_gap.items()
+            ]
+
+            rows = txn.execute_values(sql, args)
+            result.update(r for r, in rows)
+        else:
+            # For SQLite we just fall back to doing a noddy for loop.
+            sql = """
+                SELECT event_id FROM event_auth_chains
+                WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
+            """
+            for chain_id, (min_no, max_no) in chain_to_gap.items():
+                txn.execute(sql, (chain_id, min_no, max_no))
+                result.update(r for r, in txn)
+
+        return result
+
     def _get_auth_chain_difference_txn(
         self, txn, state_sets: List[Set[str]]
     ) -> Set[str]:
+        """Calculates the auth chain difference using a breadth first search.
+
+        This is used when we don't have a cover index for the room.
+        """
 
         # Algorithm Description
         # ~~~~~~~~~~~~~~~~~~~~~
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 90fb1a1f00..186f064036 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,17 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
@@ -33,9 +43,10 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
-from synapse.util.iterutils import batch_iter
+from synapse.util.iterutils import batch_iter, sorted_topologically
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -89,6 +100,14 @@ class PersistEventsStore:
         self._clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
 
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self._event_chain_id_gen = build_sequence_generator(
+            db.engine, get_chain_id_txn, "event_auth_chain_id"
+        )
+
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
@@ -366,6 +385,36 @@ class PersistEventsStore:
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
+        self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn, events_and_contexts=events_and_contexts
+        )
+
+        # From this point onwards the events are only ones that weren't
+        # rejected.
+
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+            backfilled=backfilled,
+        )
+
+        # We call this last as it assumes we've inserted the events into
+        # room_memberships, where applicable.
+        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+    def _persist_event_auth_chain_txn(
+        self, txn: LoggingTransaction, events: List[EventBase],
+    ) -> None:
+
+        # We only care about state events, so this if there are no state events.
+        if not any(e.is_state() for e in events):
+            return
+
         # We want to store event_auth mappings for rejected events, as they're
         # used in state res v2.
         # This is only necessary if the rejected event appears in an accepted
@@ -381,31 +430,357 @@ class PersistEventsStore:
                     "room_id": event.room_id,
                     "auth_id": auth_id,
                 }
-                for event, _ in events_and_contexts
+                for event in events
                 for auth_id in event.auth_event_ids()
                 if event.is_state()
             ],
         )
 
-        # _store_rejected_events_txn filters out any events which were
-        # rejected, and returns the filtered list.
-        events_and_contexts = self._store_rejected_events_txn(
-            txn, events_and_contexts=events_and_contexts
+        # We now calculate chain ID/sequence numbers for any state events we're
+        # persisting. We ignore out of band memberships as we're not in the room
+        # and won't have their auth chain (we'll fix it up later if we join the
+        # room).
+        #
+        # See: docs/auth_chain_difference_algorithm.md
+
+        # We ignore legacy rooms that we aren't filling the chain cover index
+        # for.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="rooms",
+            column="room_id",
+            iterable={event.room_id for event in events if event.is_state()},
+            keyvalues={},
+            retcols=("room_id", "has_auth_chain_index"),
         )
+        rooms_using_chain_index = {
+            row["room_id"] for row in rows if row["has_auth_chain_index"]
+        }
 
-        # From this point onwards the events are only ones that weren't
-        # rejected.
+        state_events = {
+            event.event_id: event
+            for event in events
+            if event.is_state() and event.room_id in rooms_using_chain_index
+        }
 
-        self._update_metadata_tables_txn(
+        if not state_events:
+            return
+
+        # Map from event ID to chain ID/sequence number.
+        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+
+        # We need to know the type/state_key and auth events of the events we're
+        # calculating chain IDs for. We don't rely on having the full Event
+        # instances as we'll potentially be pulling more events from the DB and
+        # we don't need the overhead of fetching/parsing the full event JSON.
+        event_to_types = {
+            e.event_id: (e.type, e.state_key) for e in state_events.values()
+        }
+        event_to_auth_chain = {
+            e.event_id: e.auth_event_ids() for e in state_events.values()
+        }
+
+        # Set of event IDs to calculate chain ID/seq numbers for.
+        events_to_calc_chain_id_for = set(state_events)
+
+        # We check if there are any events that need to be handled in the rooms
+        # we're looking at. These should just be out of band memberships, where
+        # we didn't have the auth chain when we first persisted.
+        rows = self.db_pool.simple_select_many_txn(
             txn,
-            events_and_contexts=events_and_contexts,
-            all_events_and_contexts=all_events_and_contexts,
-            backfilled=backfilled,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="room_id",
+            iterable={e.room_id for e in state_events.values()},
+            retcols=("event_id", "type", "state_key"),
         )
+        for row in rows:
+            event_id = row["event_id"]
+            event_type = row["type"]
+            state_key = row["state_key"]
+
+            # (We could pull out the auth events for all rows at once using
+            # simple_select_many, but this case happens rarely and almost always
+            # with a single row.)
+            auth_events = self.db_pool.simple_select_onecol_txn(
+                txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
+            )
 
-        # We call this last as it assumes we've inserted the events into
-        # room_memberships, where applicable.
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+            events_to_calc_chain_id_for.add(event_id)
+            event_to_types[event_id] = (event_type, state_key)
+            event_to_auth_chain[event_id] = auth_events
+
+        # First we get the chain ID and sequence numbers for the events'
+        # auth events (that aren't also currently being persisted).
+        #
+        # Note that there there is an edge case here where we might not have
+        # calculated chains and sequence numbers for events that were "out
+        # of band". We handle this case by fetching the necessary info and
+        # adding it to the set of events to calculate chain IDs for.
+
+        missing_auth_chains = {
+            a_id
+            for auth_events in event_to_auth_chain.values()
+            for a_id in auth_events
+            if a_id not in events_to_calc_chain_id_for
+        }
+
+        # We loop here in case we find an out of band membership and need to
+        # fetch their auth event info.
+        while missing_auth_chains:
+            sql = """
+                SELECT event_id, events.type, state_key, chain_id, sequence_number
+                FROM events
+                INNER JOIN state_events USING (event_id)
+                LEFT JOIN event_auth_chains USING (event_id)
+                WHERE
+            """
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", missing_auth_chains,
+            )
+            txn.execute(sql + clause, args)
+
+            missing_auth_chains.clear()
+
+            for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+                event_to_types[auth_id] = (event_type, state_key)
+
+                if chain_id is None:
+                    # No chain ID, so the event was persisted out of band.
+                    # We add to list of events to calculate auth chains for.
+
+                    events_to_calc_chain_id_for.add(auth_id)
+
+                    event_to_auth_chain[
+                        auth_id
+                    ] = self.db_pool.simple_select_onecol_txn(
+                        txn,
+                        "event_auth",
+                        keyvalues={"event_id": auth_id},
+                        retcol="auth_id",
+                    )
+
+                    missing_auth_chains.update(
+                        e
+                        for e in event_to_auth_chain[auth_id]
+                        if e not in event_to_types
+                    )
+                else:
+                    chain_map[auth_id] = (chain_id, sequence_number)
+
+        # Now we check if we have any events where we don't have auth chain,
+        # this should only be out of band memberships.
+        for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
+            for auth_id in event_to_auth_chain[event_id]:
+                if (
+                    auth_id not in chain_map
+                    and auth_id not in events_to_calc_chain_id_for
+                ):
+                    events_to_calc_chain_id_for.discard(event_id)
+
+                    # If this is an event we're trying to persist we add it to
+                    # the list of events to calculate chain IDs for next time
+                    # around. (Otherwise we will have already added it to the
+                    # table).
+                    event = state_events.get(event_id)
+                    if event:
+                        self.db_pool.simple_insert_txn(
+                            txn,
+                            table="event_auth_chain_to_calculate",
+                            values={
+                                "event_id": event.event_id,
+                                "room_id": event.room_id,
+                                "type": event.type,
+                                "state_key": event.state_key,
+                            },
+                        )
+
+                    # We stop checking the event's auth events since we've
+                    # discarded it.
+                    break
+
+        if not events_to_calc_chain_id_for:
+            return
+
+        # We now calculate the chain IDs/sequence numbers for the events. We
+        # do this by looking at the chain ID and sequence number of any auth
+        # event with the same type/state_key and incrementing the sequence
+        # number by one. If there was no match or the chain ID/sequence
+        # number is already taken we generate a new chain.
+        #
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events
+        # before the event itself.
+        chains_tuples_allocated = set()  # type: Set[Tuple[int, int]]
+        new_chain_tuples = {}  # type: Dict[str, Tuple[int, int]]
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            existing_chain_id = None
+            for auth_id in event_to_auth_chain[event_id]:
+                if event_to_types.get(event_id) == event_to_types.get(auth_id):
+                    existing_chain_id = chain_map[auth_id]
+                    break
+
+            new_chain_tuple = None
+            if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check its
+                # not already taken.
+                proposed_new_id = existing_chain_id[0]
+                proposed_new_seq = existing_chain_id[1] + 1
+                if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
+                    already_allocated = self.db_pool.simple_select_one_onecol_txn(
+                        txn,
+                        table="event_auth_chains",
+                        keyvalues={
+                            "chain_id": proposed_new_id,
+                            "sequence_number": proposed_new_seq,
+                        },
+                        retcol="event_id",
+                        allow_none=True,
+                    )
+                    if already_allocated:
+                        # Mark it as already allocated so we don't need to hit
+                        # the DB again.
+                        chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
+                    else:
+                        new_chain_tuple = (
+                            proposed_new_id,
+                            proposed_new_seq,
+                        )
+
+            if not new_chain_tuple:
+                new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
+
+            chains_tuples_allocated.add(new_chain_tuple)
+
+            chain_map[event_id] = new_chain_tuple
+            new_chain_tuples[event_id] = new_chain_tuple
+
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chains",
+            values=[
+                {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+                for event_id, (c_id, seq) in new_chain_tuples.items()
+            ],
+        )
+
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="event_id",
+            iterable=new_chain_tuples,
+        )
+
+        # Now we need to calculate any new links between chains caused by
+        # the new events.
+        #
+        # Links are pairs of chain ID/sequence numbers such that for any
+        # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
+        # if and only if there is at least one link (CA, S1) -> (CB, S2)
+        # where SA >= S1 and S2 >= SB.
+        #
+        # We try and avoid adding redundant links to the table, e.g. if we
+        # have two links between two chains which both start/end at the
+        # sequence number event (or cross) then one can be safely dropped.
+        #
+        # To calculate new links we look at every new event and:
+        #   1. Fetch the chain ID/sequence numbers of its auth events,
+        #      discarding any that are reachable by other auth events, or
+        #      that have the same chain ID as the event.
+        #   2. For each retained auth event we:
+        #       a. Add a link from the event's to the auth event's chain
+        #          ID/sequence number; and
+        #       b. Add a link from the event to every chain reachable by the
+        #          auth event.
+
+        # Step 1, fetch all existing links from all the chains we've seen
+        # referenced.
+        chain_links = _LinkMap()
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth_chain_links",
+            column="origin_chain_id",
+            iterable={chain_id for chain_id, _ in chain_map.values()},
+            keyvalues={},
+            retcols=(
+                "origin_chain_id",
+                "origin_sequence_number",
+                "target_chain_id",
+                "target_sequence_number",
+            ),
+        )
+        for row in rows:
+            chain_links.add_link(
+                (row["origin_chain_id"], row["origin_sequence_number"]),
+                (row["target_chain_id"], row["target_sequence_number"]),
+                new=False,
+            )
+
+        # We do this in toplogical order to avoid adding redundant links.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            chain_id, sequence_number = chain_map[event_id]
+
+            # Filter out auth events that are reachable by other auth
+            # events. We do this by looking at every permutation of pairs of
+            # auth events (A, B) to check if B is reachable from A.
+            reduction = {
+                a_id
+                for a_id in event_to_auth_chain[event_id]
+                if chain_map[a_id][0] != chain_id
+            }
+            for start_auth_id, end_auth_id in itertools.permutations(
+                event_to_auth_chain[event_id], r=2,
+            ):
+                if chain_links.exists_path_from(
+                    chain_map[start_auth_id], chain_map[end_auth_id]
+                ):
+                    reduction.discard(end_auth_id)
+
+            # Step 2, figure out what the new links are from the reduced
+            # list of auth events.
+            for auth_id in reduction:
+                auth_chain_id, auth_sequence_number = chain_map[auth_id]
+
+                # Step 2a, add link between the event and auth event
+                chain_links.add_link(
+                    (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
+                )
+
+                # Step 2b, add a link to chains reachable from the auth
+                # event.
+                for target_id, target_seq in chain_links.get_links_from(
+                    (auth_chain_id, auth_sequence_number)
+                ):
+                    if target_id == chain_id:
+                        continue
+
+                    chain_links.add_link(
+                        (chain_id, sequence_number), (target_id, target_seq)
+                    )
+
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chain_links",
+            values=[
+                {
+                    "origin_chain_id": source_id,
+                    "origin_sequence_number": source_seq,
+                    "target_chain_id": target_id,
+                    "target_sequence_number": target_seq,
+                }
+                for (
+                    source_id,
+                    source_seq,
+                    target_id,
+                    target_seq,
+                ) in chain_links.get_additions()
+            ],
+        )
 
     def _persist_transaction_ids_txn(
         self,
@@ -799,7 +1174,8 @@ class PersistEventsStore:
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     def _store_event_txn(self, txn, events_and_contexts):
-        """Insert new events into the event and event_json tables
+        """Insert new events into the event, event_json, redaction and
+        state_events tables.
 
         Args:
             txn (twisted.enterprise.adbapi.Connection): db connection
@@ -871,6 +1247,29 @@ class PersistEventsStore:
                     updatevalues={"have_censored": False},
                 )
 
+        state_events_and_contexts = [
+            ec for ec in events_and_contexts if ec[0].is_state()
+        ]
+
+        state_values = []
+        for event, context in state_events_and_contexts:
+            vals = {
+                "event_id": event.event_id,
+                "room_id": event.room_id,
+                "type": event.type,
+                "state_key": event.state_key,
+            }
+
+            # TODO: How does this work with backfilling?
+            if hasattr(event, "replaces_state"):
+                vals["prev_state"] = event.replaces_state
+
+            state_values.append(vals)
+
+        self.db_pool.simple_insert_many_txn(
+            txn, table="state_events", values=state_values
+        )
+
     def _store_rejected_events_txn(self, txn, events_and_contexts):
         """Add rows to the 'rejections' table for received events which were
         rejected
@@ -987,29 +1386,6 @@ class PersistEventsStore:
             txn, [event for event, _ in events_and_contexts]
         )
 
-        state_events_and_contexts = [
-            ec for ec in events_and_contexts if ec[0].is_state()
-        ]
-
-        state_values = []
-        for event, context in state_events_and_contexts:
-            vals = {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "type": event.type,
-                "state_key": event.state_key,
-            }
-
-            # TODO: How does this work with backfilling?
-            if hasattr(event, "replaces_state"):
-                vals["prev_state"] = event.replaces_state
-
-            state_values.append(vals)
-
-        self.db_pool.simple_insert_many_txn(
-            txn, table="state_events", values=state_values
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
@@ -1520,3 +1896,131 @@ class PersistEventsStore:
                 if not ev.internal_metadata.is_outlier()
             ],
         )
+
+
+@attr.s(slots=True)
+class _LinkMap:
+    """A helper type for tracking links between chains.
+    """
+
+    # Stores the set of links as nested maps: source chain ID -> target chain ID
+    # -> source sequence number -> target sequence number.
+    maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+
+    # Stores the links that have been added (with new set to true), as tuples of
+    # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
+    additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+
+    def add_link(
+        self,
+        src_tuple: Tuple[int, int],
+        target_tuple: Tuple[int, int],
+        new: bool = True,
+    ) -> bool:
+        """Add a new link between two chains, ensuring no redundant links are added.
+
+        New links should be added in topological order.
+
+        Args:
+            src_tuple: The chain ID/sequence number of the source of the link.
+            target_tuple: The chain ID/sequence number of the target of the link.
+            new: Whether this is a "new" link, i.e. should it be returned
+                by `get_additions`.
+
+        Returns:
+            True if a link was added, false if the given link was dropped as redundant
+        """
+        src_chain, src_seq = src_tuple
+        target_chain, target_seq = target_tuple
+
+        current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
+
+        assert src_chain != target_chain
+
+        if new:
+            # Check if the new link is redundant
+            for current_seq_src, current_seq_target in current_links.items():
+                # If a link "crosses" another link then its redundant. For example
+                # in the following link 1 (L1) is redundant, as any event reachable
+                # via L1 is *also* reachable via L2.
+                #
+                #   Chain A     Chain B
+                #      |          |
+                #   L1 |------    |
+                #      |     |    |
+                #   L2 |---- | -->|
+                #      |     |    |
+                #      |     |--->|
+                #      |          |
+                #      |          |
+                #
+                # So we only need to keep links which *do not* cross, i.e. links
+                # that both start and end above or below an existing link.
+                #
+                # Note, since we add links in topological ordering we should never
+                # see `src_seq` less than `current_seq_src`.
+
+                if current_seq_src <= src_seq and target_seq <= current_seq_target:
+                    # This new link is redundant, nothing to do.
+                    return False
+
+            self.additions.add((src_chain, src_seq, target_chain, target_seq))
+
+        current_links[src_seq] = target_seq
+        return True
+
+    def get_links_from(
+        self, src_tuple: Tuple[int, int]
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Gets the chains reachable from the given chain/sequence number.
+
+        Yields:
+            The chain ID and sequence number the link points to.
+        """
+        src_chain, src_seq = src_tuple
+        for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
+            for link_src_seq, target_seq in sequence_numbers.items():
+                if link_src_seq <= src_seq:
+                    yield target_id, target_seq
+
+    def get_links_between(
+        self, source_chain: int, target_chain: int
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Gets the links between two chains.
+
+        Yields:
+            The source and target sequence numbers.
+        """
+
+        yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
+
+    def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
+        """Gets any newly added links.
+
+        Yields:
+            The source chain ID/sequence number and target chain ID/sequence number
+        """
+
+        for src_chain, src_seq, target_chain, _ in self.additions:
+            target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
+            if target_seq is not None:
+                yield (src_chain, src_seq, target_chain, target_seq)
+
+    def exists_path_from(
+        self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
+    ) -> bool:
+        """Checks if there is a path between the source chain ID/sequence and
+        target chain ID/sequence.
+        """
+        src_chain, src_seq = src_tuple
+        target_chain, target_seq = target_tuple
+
+        if src_chain == target_chain:
+            return target_seq <= src_seq
+
+        links = self.get_links_between(src_chain, target_chain)
+        for link_start_seq, link_end_seq in links:
+            if link_start_seq <= src_seq and target_seq <= link_end_seq:
+                return True
+
+        return False
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 97b6754846..7e4b175d08 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -14,10 +14,15 @@
 # limitations under the License.
 
 import logging
+from typing import List, Tuple
 
 from synapse.api.constants import EventContentFields
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
+from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
@@ -99,6 +104,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             columns=["user_id", "created_ts"],
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "rejected_events_metadata", self._rejected_events_metadata,
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -582,3 +591,118 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             await self.db_pool.updates._end_background_update("event_store_labels")
 
         return num_rows
+
+    async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int:
+        """Adds rejected events to the `state_events` and `event_auth` metadata
+        tables.
+        """
+
+        last_event_id = progress.get("last_event_id", "")
+
+        def get_rejected_events(
+            txn: Cursor,
+        ) -> List[Tuple[str, str, JsonDict, bool, bool]]:
+            # Fetch rejected event json, their room version and whether we have
+            # inserted them into the state_events or auth_events tables.
+            #
+            # Note we can assume that events that don't have a corresponding
+            # room version are V1 rooms.
+            sql = """
+                SELECT DISTINCT
+                    event_id,
+                    COALESCE(room_version, '1'),
+                    json,
+                    state_events.event_id IS NOT NULL,
+                    event_auth.event_id IS NOT NULL
+                FROM rejections
+                INNER JOIN event_json USING (event_id)
+                LEFT JOIN rooms USING (room_id)
+                LEFT JOIN state_events USING (event_id)
+                LEFT JOIN event_auth USING (event_id)
+                WHERE event_id > ?
+                ORDER BY event_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_event_id, batch_size,))
+
+            return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
+
+        results = await self.db_pool.runInteraction(
+            desc="_rejected_events_metadata_get", func=get_rejected_events
+        )
+
+        if not results:
+            await self.db_pool.updates._end_background_update(
+                "rejected_events_metadata"
+            )
+            return 0
+
+        state_events = []
+        auth_events = []
+        for event_id, room_version, event_json, has_state, has_event_auth in results:
+            last_event_id = event_id
+
+            if has_state and has_event_auth:
+                continue
+
+            room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version)
+            if not room_version_obj:
+                # We no longer support this room version, so we just ignore the
+                # events entirely.
+                logger.info(
+                    "Ignoring event with unknown room version %r: %r",
+                    room_version,
+                    event_id,
+                )
+                continue
+
+            event = make_event_from_dict(event_json, room_version_obj)
+
+            if not event.is_state():
+                continue
+
+            if not has_state:
+                state_events.append(
+                    {
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "type": event.type,
+                        "state_key": event.state_key,
+                    }
+                )
+
+            if not has_event_auth:
+                for auth_id in event.auth_event_ids():
+                    auth_events.append(
+                        {
+                            "room_id": event.room_id,
+                            "event_id": event.event_id,
+                            "auth_id": auth_id,
+                        }
+                    )
+
+        if state_events:
+            await self.db_pool.simple_insert_many(
+                table="state_events",
+                values=state_events,
+                desc="_rejected_events_metadata_state_events",
+            )
+
+        if auth_events:
+            await self.db_pool.simple_insert_many(
+                table="event_auth",
+                values=auth_events,
+                desc="_rejected_events_metadata_event_auth",
+            )
+
+        await self.db_pool.updates._background_update_progress(
+            "rejected_events_metadata", {"last_event_id": last_event_id}
+        )
+
+        if len(results) < batch_size:
+            await self.db_pool.updates._end_background_update(
+                "rejected_events_metadata"
+            )
+
+        return len(results)
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 0e25ca3d7a..54ef0f1f54 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def set_profile_avatar_url(
-        self, user_localpart: str, new_avatar_url: str
+        self, user_localpart: str, new_avatar_url: Optional[str]
     ) -> None:
         await self.db_pool.simple_update_one(
             table="profiles",
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4650d0689b..284f2ce77c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore):
         return await self.db_pool.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
-            retcols=("room_id", "is_public", "creator"),
+            retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
             desc="get_room",
             allow_none=True,
         )
@@ -1166,6 +1166,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         # It's overridden by RoomStore for the synapse master.
         raise NotImplementedError()
 
+    async def has_auth_chain_index(self, room_id: str) -> bool:
+        """Check if the room has (or can have) a chain cover index.
+
+        Defaults to True if we don't have an entry in `rooms` table nor any
+        events for the room.
+        """
+
+        has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcol="has_auth_chain_index",
+            desc="has_auth_chain_index",
+            allow_none=True,
+        )
+
+        if has_auth_chain_index:
+            return True
+
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        max_ordering = await self.db_pool.simple_select_one_onecol(
+            table="events",
+            keyvalues={"room_id": room_id},
+            retcol="MAX(stream_ordering)",
+            allow_none=True,
+            desc="upsert_room_on_join",
+        )
+
+        return max_ordering is None
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1179,12 +1210,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         Called when we join a room over federation, and overwrites any room version
         currently in the table.
         """
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
         await self.db_pool.simple_upsert(
             desc="upsert_room_on_join",
             table="rooms",
             keyvalues={"room_id": room_id},
             values={"room_version": room_version.identifier},
-            insertion_values={"is_public": False, "creator": ""},
+            insertion_values={
+                "is_public": False,
+                "creator": "",
+                "has_auth_chain_index": has_auth_chain_index,
+            },
             # rooms has a unique constraint on room_id, so no need to lock when doing an
             # emulated upsert.
             lock=False,
@@ -1219,6 +1259,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         "creator": room_creator_user_id,
                         "is_public": is_public,
                         "room_version": room_version.identifier,
+                        "has_auth_chain_index": True,
                     },
                 )
                 if is_public:
@@ -1247,6 +1288,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         When we receive an invite or any other event over federation that may relate to a room
         we are not in, store the version of the room if we don't already know the room version.
         """
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
         await self.db_pool.simple_upsert(
             desc="maybe_store_room_on_outlier_membership",
             table="rooms",
@@ -1256,6 +1302,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 "room_version": room_version.identifier,
                 "is_public": False,
                 "creator": "",
+                "has_auth_chain_index": has_auth_chain_index,
             },
             # rooms has a unique constraint on room_id, so no need to lock when doing an
             # emulated upsert.
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
new file mode 100644
index 0000000000..de57645019
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE access_tokens DROP COLUMN last_used;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
new file mode 100644
index 0000000000..ee0e3521bf
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ -- Dropping last_used column from access_tokens table.
+
+CREATE TABLE access_tokens2 (
+    id BIGINT PRIMARY KEY, 
+    user_id TEXT NOT NULL, 
+    device_id TEXT, 
+    token TEXT NOT NULL,
+    valid_until_ms BIGINT,
+    puppets_user_id TEXT,
+    last_validated BIGINT,
+    UNIQUE(token) 
+);
+
+INSERT INTO access_tokens2(id, user_id, device_id, token)
+    SELECT id, user_id, device_id, token FROM access_tokens;
+
+DROP TABLE access_tokens;
+ALTER TABLE access_tokens2 RENAME TO access_tokens;
+
+CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id);
+
+
+-- Re-adding foreign key reference in event_txn_id table
+
+CREATE TABLE event_txn_id2 (
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    token_id BIGINT NOT NULL,
+    txn_id TEXT NOT NULL,
+    inserted_ts BIGINT NOT NULL,
+    FOREIGN KEY (event_id)
+        REFERENCES events (event_id) ON DELETE CASCADE,
+    FOREIGN KEY (token_id)
+        REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+INSERT INTO event_txn_id2(event_id, room_id, user_id, token_id, txn_id, inserted_ts)
+    SELECT event_id, room_id, user_id, token_id, txn_id, inserted_ts FROM event_txn_id;
+
+DROP TABLE event_txn_id;
+ALTER TABLE event_txn_id2 RENAME TO event_txn_id;
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql
new file mode 100644
index 0000000000..9c95646281
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5828, 'rejected_events_metadata', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
new file mode 100644
index 0000000000..f35c70b699
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -0,0 +1,82 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration denormalises the account_data table into an ignored users table.
+"""
+
+import logging
+from io import StringIO
+
+from synapse.storage._base import db_to_json
+from synapse.storage.engines import BaseDatabaseEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
+from synapse.storage.types import Cursor
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    pass
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    logger.info("Creating ignored_users table")
+    execute_statements_from_stream(cur, StringIO(_create_commands))
+
+    # We now upgrade existing data, if any. We don't do this in `run_upgrade` as
+    # we a) want to run these before adding constraints and b) `run_upgrade` is
+    # not run on empty databases.
+    insert_sql = """
+    INSERT INTO ignored_users (ignorer_user_id, ignored_user_id) VALUES (?, ?)
+    """
+
+    logger.info("Converting existing ignore lists")
+    cur.execute(
+        "SELECT user_id, content FROM account_data WHERE account_data_type = 'm.ignored_user_list'"
+    )
+    for user_id, content_json in cur.fetchall():
+        content = db_to_json(content_json)
+
+        # The content should be the form of a dictionary with a key
+        # "ignored_users" pointing to a dictionary with keys of ignored users.
+        #
+        # { "ignored_users": "@someone:example.org": {} }
+        ignored_users = content.get("ignored_users", {})
+        if isinstance(ignored_users, dict) and ignored_users:
+            cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+
+    # Add indexes after inserting data for efficiency.
+    logger.info("Adding constraints to ignored_users table")
+    execute_statements_from_stream(cur, StringIO(_constraints_commands))
+
+
+# there might be duplicates, so the easiest way to achieve this is to create a new
+# table with the right data, and renaming it into place
+
+_create_commands = """
+-- Users which are ignored when calculating push notifications. This data is
+-- denormalized from account data.
+CREATE TABLE IF NOT EXISTS ignored_users(
+    ignorer_user_id TEXT NOT NULL,  -- The user ID of the user who is ignoring another user. (This is a local user.)
+    ignored_user_id TEXT NOT NULL  -- The user ID of the user who is being ignored. (This is a local or remote user.)
+);
+"""
+
+_constraints_commands = """
+CREATE UNIQUE INDEX ignored_users_uniqueness ON ignored_users (ignorer_user_id, ignored_user_id);
+
+-- Add an index on ignored_users since look-ups are done to get all ignorers of an ignored user.
+CREATE INDEX ignored_users_ignored_user_id ON ignored_users (ignored_user_id);
+"""
diff --git a/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
new file mode 100644
index 0000000000..d781a92fec
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE device_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
new file mode 100644
index 0000000000..45a845a3a5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence;
+
+-- We need to take the max across both device_inbox and device_federation_outbox
+-- tables as they share the ID generator
+SELECT setval('device_inbox_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox)
+    )
+));
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
new file mode 100644
index 0000000000..729196cfd5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
@@ -0,0 +1,52 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- See docs/auth_chain_difference_algorithm.md
+
+CREATE TABLE event_auth_chains (
+  event_id TEXT PRIMARY KEY,
+  chain_id BIGINT NOT NULL,
+  sequence_number BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number);
+
+
+CREATE TABLE event_auth_chain_links (
+  origin_chain_id BIGINT NOT NULL,
+  origin_sequence_number BIGINT NOT NULL,
+
+  target_chain_id BIGINT NOT NULL,
+  target_sequence_number BIGINT NOT NULL
+);
+
+
+CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id);
+
+
+-- Events that we have persisted but not calculated auth chains for,
+-- e.g. out of band memberships (where we don't have the auth chain)
+CREATE TABLE event_auth_chain_to_calculate (
+  event_id TEXT PRIMARY KEY,
+  room_id TEXT NOT NULL,
+  type TEXT NOT NULL,
+  state_key TEXT NOT NULL
+);
+
+CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id);
+
+
+-- Whether we've calculated the above index for a room.
+ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
new file mode 100644
index 0000000000..e8a035bbeb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
new file mode 100644
index 0000000000..64ab696cfe
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS account_data_max_stream_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
new file mode 100644
index 0000000000..fb71b360a0
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS cache_invalidation_stream;
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 9f120d3cb6..74da9c49f2 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -255,16 +255,6 @@ class TagsStore(TagsWorkerStore):
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
         )
 
-        # Note: This is only here for backwards compat to allow admins to
-        # roll back to a previous Synapse version. Next time we update the
-        # database version we can remove this table.
-        update_max_id_sql = (
-            "UPDATE account_data_max_stream_id"
-            " SET stream_id = ?"
-            " WHERE stream_id < ?"
-        )
-        txn.execute(update_max_id_sql, (next_id, next_id))
-
         update_sql = (
             "UPDATE room_tags_revisions"
             " SET stream_id = ?"
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index f91a2eae7a..566ea19bae 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -35,10 +35,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-# XXX: If you're about to bump this to 59 (or higher) please create an update
-# that drops the unused `cache_invalidation_stream` table, as per #7436!
-# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
-SCHEMA_VERSION = 58
+SCHEMA_VERSION = 59
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -375,7 +372,16 @@ def _upgrade_existing_database(
     specific_engine_extensions = (".sqlite", ".postgres")
 
     for v in range(start_ver, SCHEMA_VERSION + 1):
-        logger.info("Applying schema deltas for v%d", v)
+        if not is_worker:
+            logger.info("Applying schema deltas for v%d", v)
+
+            cur.execute("DELETE FROM schema_version")
+            cur.execute(
+                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
+                (v, True),
+            )
+        else:
+            logger.info("Checking schema deltas for v%d", v)
 
         # We need to search both the global and per data store schema
         # directories for schema updates.
@@ -489,12 +495,6 @@ def _upgrade_existing_database(
                 (v, relative_path),
             )
 
-            cur.execute("DELETE FROM schema_version")
-            cur.execute(
-                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
-                (v, True),
-            )
-
     logger.info("Schema now up to date")
 
 
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 601305487c..1adc92eb90 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -105,7 +105,7 @@ class DeferredCache(Generic[KT, VT]):
             keylen=keylen,
             cache_name=name,
             cache_type=cache_type,
-            size_callback=(lambda d: len(d)) if iterable else None,
+            size_callback=(lambda d: len(d) or 1) if iterable else None,
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
         )  # type: LruCache[KT, VT]
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 06faeebe7f..f7b4857a84 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -13,8 +13,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import heapq
 from itertools import islice
-from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+from typing import (
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    Mapping,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+)
+
+from synapse.types import Collection
 
 T = TypeVar("T")
 
@@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
     If the input is empty, no chunks are returned.
     """
     return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+
+
+def sorted_topologically(
+    nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+) -> Generator[T, None, None]:
+    """Given a set of nodes and a graph, yield the nodes in toplogical order.
+
+    For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
+    """
+
+    # This is implemented by Kahn's algorithm.
+
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph = {}  # type: Dict[T, Set[T]]
+
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+
+        for edge in edges:
+            if edge in degree_map:
+                degree_map[node] += 1
+
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+    heapq.heapify(zero_degree)
+
+    while zero_degree:
+        node = heapq.heappop(zero_degree)
+        yield node
+
+        for edge in reverse_graph[node]:
+            if edge in degree_map:
+                degree_map[edge] -= 1
+                if degree_map[edge] == 0:
+                    heapq.heappush(zero_degree, edge)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ffdea0de8d..f4de6b9f54 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -108,7 +108,16 @@ class Measure:
     def __init__(self, clock, name):
         self.clock = clock
         self.name = name
-        parent_context = current_context()
+        curr_context = current_context()
+        if not curr_context:
+            logger.warning(
+                "Starting metrics collection %r from sentinel context: metrics will be lost",
+                name,
+            )
+            parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
         self._logging_context = LoggingContext(
             "Measure[%s]" % (self.name,), parent_context
         )
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
new file mode 100644
index 0000000000..10363e3765
--- /dev/null
+++ b/tests/config/test_util.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config import ConfigError
+from synapse.config._util import validate_config
+
+from tests.unittest import TestCase
+
+
+class ValidateConfigTestCase(TestCase):
+    """Test cases for synapse.config._util.validate_config"""
+
+    def test_bad_object_in_array(self):
+        """malformed objects within an array should be validated correctly"""
+
+        # consider a structure:
+        #
+        # array_of_objs:
+        #   - r: 1
+        #     foo: 2
+        #
+        #   - r: 2
+        #     bar: 3
+        #
+        # ... where each entry must contain an "r": check that the path
+        # to the required item is correclty reported.
+
+        schema = {
+            "type": "object",
+            "properties": {
+                "array_of_objs": {
+                    "type": "array",
+                    "items": {"type": "object", "required": ["r"]},
+                },
+            },
+        }
+
+        with self.assertRaises(ConfigError) as c:
+            validate_config(schema, {"array_of_objs": [{}]}, ("base",))
+
+        self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index c1274c14af..8ba36c6074 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -34,11 +34,17 @@ def MockEvent(**kwargs):
 
 
 class PruneEventTestCase(unittest.TestCase):
-    """ Asserts that a new event constructed with `evdict` will look like
-    `matchdict` when it is redacted. """
-
     def run_test(self, evdict, matchdict, **kwargs):
-        self.assertEquals(
+        """
+        Asserts that a new event constructed with `evdict` will look like
+        `matchdict` when it is redacted.
+
+        Args:
+             evdict: The dictionary to build the event from.
+             matchdict: The expected resulting dictionary.
+             kwargs: Additional keyword arguments used to create the event.
+        """
+        self.assertEqual(
             prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
         )
 
@@ -55,54 +61,80 @@ class PruneEventTestCase(unittest.TestCase):
         )
 
     def test_basic_keys(self):
+        """Ensure that the keys that should be untouched are kept."""
+        # Note that some of the values below don't really make sense, but the
+        # pruning of events doesn't worry about the values of any fields (with
+        # the exception of the content field).
         self.run_test(
             {
+                "event_id": "$3:domain",
                 "type": "A",
                 "room_id": "!1:domain",
                 "sender": "@2:domain",
-                "event_id": "$3:domain",
+                "state_key": "B",
+                "content": {"other_key": "foo"},
+                "hashes": "hashes",
+                "signatures": {"domain": {"algo:1": "sigs"}},
+                "depth": 4,
+                "prev_events": "prev_events",
+                "prev_state": "prev_state",
+                "auth_events": "auth_events",
                 "origin": "domain",
+                "origin_server_ts": 1234,
+                "membership": "join",
+                # Also include a key that should be removed.
+                "other_key": "foo",
             },
             {
+                "event_id": "$3:domain",
                 "type": "A",
                 "room_id": "!1:domain",
                 "sender": "@2:domain",
-                "event_id": "$3:domain",
+                "state_key": "B",
+                "hashes": "hashes",
+                "depth": 4,
+                "prev_events": "prev_events",
+                "prev_state": "prev_state",
+                "auth_events": "auth_events",
                 "origin": "domain",
+                "origin_server_ts": 1234,
+                "membership": "join",
                 "content": {},
-                "signatures": {},
+                "signatures": {"domain": {"algo:1": "sigs"}},
                 "unsigned": {},
             },
         )
 
-    def test_unsigned_age_ts(self):
+        # As of MSC2176 we now redact the membership and prev_states keys.
         self.run_test(
-            {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}},
-            {
-                "type": "B",
-                "event_id": "$test:domain",
-                "content": {},
-                "signatures": {},
-                "unsigned": {"age_ts": 20},
-            },
+            {"type": "A", "prev_state": "prev_state", "membership": "join"},
+            {"type": "A", "content": {}, "signatures": {}, "unsigned": {}},
+            room_version=RoomVersions.MSC2176,
         )
 
+    def test_unsigned(self):
+        """Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
         self.run_test(
             {
                 "type": "B",
                 "event_id": "$test:domain",
-                "unsigned": {"other_key": "here"},
+                "unsigned": {
+                    "age_ts": 20,
+                    "replaces_state": "$test2:domain",
+                    "other_key": "foo",
+                },
             },
             {
                 "type": "B",
                 "event_id": "$test:domain",
                 "content": {},
                 "signatures": {},
-                "unsigned": {},
+                "unsigned": {"age_ts": 20, "replaces_state": "$test2:domain"},
             },
         )
 
     def test_content(self):
+        """The content dictionary should be stripped in most cases."""
         self.run_test(
             {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
             {
@@ -114,11 +146,35 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
+        # Some events keep a single content key/value.
+        EVENT_KEEP_CONTENT_KEYS = [
+            ("member", "membership", "join"),
+            ("join_rules", "join_rule", "invite"),
+            ("history_visibility", "history_visibility", "shared"),
+        ]
+        for event_type, key, value in EVENT_KEEP_CONTENT_KEYS:
+            self.run_test(
+                {
+                    "type": "m.room." + event_type,
+                    "event_id": "$test:domain",
+                    "content": {key: value, "other_key": "foo"},
+                },
+                {
+                    "type": "m.room." + event_type,
+                    "event_id": "$test:domain",
+                    "content": {key: value},
+                    "signatures": {},
+                    "unsigned": {},
+                },
+            )
+
+    def test_create(self):
+        """Create events are partially redacted until MSC2176."""
         self.run_test(
             {
                 "type": "m.room.create",
                 "event_id": "$test:domain",
-                "content": {"creator": "@2:domain", "other_field": "here"},
+                "content": {"creator": "@2:domain", "other_key": "foo"},
             },
             {
                 "type": "m.room.create",
@@ -129,6 +185,68 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
+        # After MSC2176, create events get nothing redacted.
+        self.run_test(
+            {"type": "m.room.create", "content": {"not_a_real_key": True}},
+            {
+                "type": "m.room.create",
+                "content": {"not_a_real_key": True},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
+    def test_power_levels(self):
+        """Power level events keep a variety of content keys."""
+        self.run_test(
+            {
+                "type": "m.room.power_levels",
+                "event_id": "$test:domain",
+                "content": {
+                    "ban": 1,
+                    "events": {"m.room.name": 100},
+                    "events_default": 2,
+                    "invite": 3,
+                    "kick": 4,
+                    "redact": 5,
+                    "state_default": 6,
+                    "users": {"@admin:domain": 100},
+                    "users_default": 7,
+                    "other_key": 8,
+                },
+            },
+            {
+                "type": "m.room.power_levels",
+                "event_id": "$test:domain",
+                "content": {
+                    "ban": 1,
+                    "events": {"m.room.name": 100},
+                    "events_default": 2,
+                    # Note that invite is not here.
+                    "kick": 4,
+                    "redact": 5,
+                    "state_default": 6,
+                    "users": {"@admin:domain": 100},
+                    "users_default": 7,
+                },
+                "signatures": {},
+                "unsigned": {},
+            },
+        )
+
+        # After MSC2176, power levels events keep the invite key.
+        self.run_test(
+            {"type": "m.room.power_levels", "content": {"invite": 75}},
+            {
+                "type": "m.room.power_levels",
+                "content": {"invite": 75},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
     def test_alias_event(self):
         """Alias events have special behavior up through room version 6."""
         self.run_test(
@@ -146,8 +264,7 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
-    def test_msc2432_alias_event(self):
-        """After MSC2432, alias events have no special behavior."""
+        # After MSC2432, alias events have no special behavior.
         self.run_test(
             {"type": "m.room.aliases", "content": {"aliases": ["test"]}},
             {
@@ -159,6 +276,32 @@ class PruneEventTestCase(unittest.TestCase):
             room_version=RoomVersions.V6,
         )
 
+    def test_redacts(self):
+        """Redaction events have no special behaviour until MSC2174/MSC2176."""
+
+        self.run_test(
+            {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+            {
+                "type": "m.room.redaction",
+                "content": {},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.V6,
+        )
+
+        # After MSC2174, redaction events keep the redacts content key.
+        self.run_test(
+            {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+            {
+                "type": "m.room.redaction",
+                "content": {"redacts": "$test2:domain"},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
 
 class SerializeEventTestCase(unittest.TestCase):
     def serialize(self, ev, fields):
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index bd7a1b6891..c37bb6440e 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -118,4 +118,4 @@ class CasHandlerTestCase(HomeserverTestCase):
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "get_user_agent"])
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 368d600b33..2abd7a83b5 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import json
 import re
-from typing import Dict
+from typing import Dict, Optional
 from urllib.parse import parse_qs, urlencode, urlparse
 
 from mock import ANY, Mock, patch
@@ -24,7 +24,6 @@ import pymacaroons
 from twisted.web.resource import Resource
 
 from synapse.api.errors import RedirectException
-from synapse.handlers.oidc_handler import OidcError
 from synapse.handlers.sso import MappingException
 from synapse.rest.client.v1 import login
 from synapse.rest.synapse.client.pick_username import pick_username_resource
@@ -34,6 +33,14 @@ from synapse.types import UserID
 from tests.test_utils import FakeResponse, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
+try:
+    import authlib  # noqa: F401
+
+    HAS_OIDC = True
+except ImportError:
+    HAS_OIDC = False
+
+
 # These are a few constants that are used as config parameters in the tests.
 ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
@@ -113,6 +120,9 @@ async def get_json(url):
 
 
 class OidcHandlerTestCase(HomeserverTestCase):
+    if not HAS_OIDC:
+        skip = "requires OIDC"
+
     def default_config(self):
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
@@ -339,9 +349,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         cookie = args[1]
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._get_value_from_macaroon(macaroon, "state")
-        nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
-        redirect = self.handler._get_value_from_macaroon(
+        state = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "state"
+        )
+        nonce = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "nonce"
+        )
+        redirect = self.handler._token_generator._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
 
@@ -401,12 +415,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         client_redirect_url = "http://client/redirect"
         user_agent = "Browser"
         ip_address = "10.0.0.1"
-        session = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
-        )
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
         request = _build_callback_request(
             code, state, session, user_agent=user_agent, ip_address=ip_address
         )
@@ -458,6 +467,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
+        from synapse.handlers.oidc_handler import OidcError
+
         self.handler._exchange_code = simple_async_mock(
             raises=OidcError("invalid_request")
         )
@@ -488,11 +499,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_session")
 
         # Mismatching session
-        session = self.handler._generate_oidc_session_token(
-            state="state",
-            nonce="nonce",
-            client_redirect_url="http://client/redirect",
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
         )
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
@@ -538,6 +546,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 body=b'{"error": "foo", "error_description": "bar"}',
             )
         )
+        from synapse.handlers.oidc_handler import OidcError
+
         exc = self.get_failure(self.handler._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "foo")
         self.assertEqual(exc.value.error_description, "bar")
@@ -609,11 +619,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         state = "state"
         client_redirect_url = "http://client/redirect"
-        session = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state=state, nonce="nonce", client_redirect_url=client_redirect_url,
         )
         request = _build_callback_request("code", state, session)
 
@@ -827,8 +834,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
+    def _generate_oidc_session_token(
+        self,
+        state: str,
+        nonce: str,
+        client_redirect_url: str,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        from synapse.handlers.oidc_handler import OidcSessionData
+
+        return self.handler._token_generator.generate_oidc_session_token(
+            state=state,
+            session_data=OidcSessionData(
+                nonce=nonce,
+                client_redirect_url=client_redirect_url,
+                ui_auth_session_id=ui_auth_session_id,
+            ),
+        )
+
 
 class UsernamePickerTestCase(HomeserverTestCase):
+    if not HAS_OIDC:
+        skip = "requires OIDC"
+
     servlets = [login.register_servlets]
 
     def default_config(self):
@@ -948,17 +976,19 @@ async def _make_callback_with_userinfo(
         userinfo: the OIDC userinfo dict
         client_redirect_url: the URL to redirect to on success.
     """
+    from synapse.handlers.oidc_handler import OidcSessionData
+
     handler = hs.get_oidc_handler()
     handler._exchange_code = simple_async_mock(return_value={})
     handler._parse_id_token = simple_async_mock(return_value=userinfo)
     handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
     state = "state"
-    session = handler._generate_oidc_session_token(
+    session = handler._token_generator.generate_oidc_session_token(
         state=state,
-        nonce="nonce",
-        client_redirect_url=client_redirect_url,
-        ui_auth_session_id=None,
+        session_data=OidcSessionData(
+            nonce="nonce", client_redirect_url=client_redirect_url,
+        ),
     )
     request = _build_callback_request("code", state, session)
 
@@ -994,7 +1024,7 @@ def _build_callback_request(
             "addCookie",
             "requestHeaders",
             "getClientIP",
-            "get_user_agent",
+            "getHeader",
         ]
     )
 
@@ -1003,5 +1033,4 @@ def _build_callback_request(
     request.args[b"code"] = [code.encode("utf-8")]
     request.args[b"state"] = [state.encode("utf-8")]
     request.getClientIP.return_value = ip_address
-    request.get_user_agent.return_value = user_agent
     return request
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 919547556b..022943a10a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -105,6 +105,21 @@ class ProfileTestCase(unittest.TestCase):
             "Frank",
         )
 
+        # Set displayname to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), ""
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     def test_set_my_name_if_disabled(self):
         self.hs.config.enable_set_displayname = False
@@ -223,6 +238,21 @@ class ProfileTestCase(unittest.TestCase):
             "http://my.server/me.png",
         )
 
+        # Set avatar to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank, synapse.types.create_requester(self.frank), "",
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
+        )
+
     @defer.inlineCallbacks
     def test_set_my_avatar_if_disabled(self):
         self.hs.config.enable_set_avatar_url = False
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 548038214b..261c7083d1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -262,4 +262,4 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "get_user_agent"])
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 212484a7fe..9c52c8fdca 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -560,4 +560,4 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         f = self.failureResultOf(test_d)
-        self.assertIsInstance(f.value, ValueError)
+        self.assertIsInstance(f.value, RequestSendFailed)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 0504cd187e..586b877bda 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -58,8 +58,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -156,7 +154,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.hs = hs
 
         # Allow for uploading and downloading to/from the media repo
         self.media_repo = hs.get_media_repository_resource()
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index aa389df12f..d0090faa4f 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -32,8 +32,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -371,8 +369,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index c2b998cdae..51a7731693 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.handler = hs.get_device_handler()
         self.media_repo = hs.get_media_repository_resource()
         self.server_name = hs.hostname
 
@@ -181,7 +180,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.handler = hs.get_device_handler()
         self.media_repo = hs.get_media_repository_resource()
         self.server_name = hs.hostname
 
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index fa620f97f3..a0f32c5512 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -605,8 +605,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         # Create user
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 73f8a8ec99..f48be3d65a 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -31,7 +31,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9b2e4765f6..04599c2fcf 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -25,6 +25,7 @@ from mock import Mock
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.api.room_versions import RoomVersions
 from synapse.rest.client.v1 import login, logout, profile, room
 from synapse.rest.client.v2_alpha import devices, sync
 
@@ -587,6 +588,200 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         _search_test(None, "bar", "user_id")
 
 
+class DeactivateAccountTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass", displayname="User1")
+        self.other_user_token = self.login("user", "pass")
+        self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
+            self.other_user
+        )
+        self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote(
+            self.other_user
+        )
+
+        # set attributes for user
+        self.get_success(
+            self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+        )
+        self.get_success(
+            self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+        )
+
+    def test_no_auth(self):
+        """
+        Try to deactivate users without authentication.
+        """
+        channel = self.make_request("POST", self.url, b"{}")
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        url = "/_synapse/admin/v1/deactivate/@bob:test"
+
+        channel = self.make_request("POST", url, access_token=self.other_user_token)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+        channel = self.make_request(
+            "POST", url, access_token=self.other_user_token, content=b"{}",
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+    def test_user_does_not_exist(self):
+        """
+        Tests that deactivation for a user that does not exist returns a 404
+        """
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/deactivate/@unknown_person:test",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_erase_is_not_bool(self):
+        """
+        If parameter `erase` is not boolean, return an error
+        """
+        body = json.dumps({"erase": "False"})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that deactivation for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual("Can only deactivate local users", channel.json_body["error"])
+
+    def test_deactivate_user_erase_true(self):
+        """
+        Test deactivating an user and set `erase` to `true`
+        """
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        # Deactivate user
+        body = json.dumps({"erase": True})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertIsNone(channel.json_body["avatar_url"])
+        self.assertIsNone(channel.json_body["displayname"])
+
+        self._is_erased("@user:test", True)
+
+    def test_deactivate_user_erase_false(self):
+        """
+        Test deactivating an user and set `erase` to `false`
+        """
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        # Deactivate user
+        body = json.dumps({"erase": False})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        self._is_erased("@user:test", False)
+
+    def _is_erased(self, user_id: str, expect: bool) -> None:
+        """Assert that the user is erased or not
+        """
+        d = self.store.is_user_erased(user_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertFalse(self.get_success(d))
+
+
 class UserRestTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -986,6 +1181,26 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         Test deactivating another user.
         """
 
+        # set attributes for user
+        self.get_success(
+            self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+        )
+        self.get_success(
+            self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+        )
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
+
         # Deactivate user
         body = json.dumps({"deactivated": True})
 
@@ -999,6 +1214,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
         # the user is deactivated, the threepid will be deleted
 
         # Get user
@@ -1009,6 +1227,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
 
     @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
     def test_change_name_deactivate_user_user_directory(self):
@@ -1204,8 +1425,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -1236,24 +1455,26 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
 
     def test_user_does_not_exist(self):
         """
-        Tests that a lookup for a user that does not exist returns a 404
+        Tests that a lookup for a user that does not exist returns an empty list
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
         channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["total"])
+        self.assertEqual(0, len(channel.json_body["joined_rooms"]))
 
     def test_user_is_not_local(self):
         """
-        Tests that a lookup for a user that is not a local returns a 400
+        Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
 
         channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["total"])
+        self.assertEqual(0, len(channel.json_body["joined_rooms"]))
 
     def test_no_memberships(self):
         """
@@ -1284,6 +1505,49 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(number_rooms, channel.json_body["total"])
         self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
 
+    def test_get_rooms_with_nonlocal_user(self):
+        """
+        Tests that a normal lookup for rooms is successful with a non-local user
+        """
+
+        other_user_tok = self.login("user", "pass")
+        event_builder_factory = self.hs.get_event_builder_factory()
+        event_creation_handler = self.hs.get_event_creation_handler()
+        storage = self.hs.get_storage()
+
+        # Create two rooms, one with a local user only and one with both a local
+        # and remote user.
+        self.helper.create_room_as(self.other_user, tok=other_user_tok)
+        local_and_remote_room_id = self.helper.create_room_as(
+            self.other_user, tok=other_user_tok
+        )
+
+        # Add a remote user to the room.
+        builder = event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": "m.room.member",
+                "sender": "@joiner:remote_hs",
+                "state_key": "@joiner:remote_hs",
+                "room_id": local_and_remote_room_id,
+                "content": {"membership": "join"},
+            },
+        )
+
+        event, context = self.get_success(
+            event_creation_handler.create_new_client_event(builder)
+        )
+
+        self.get_success(storage.persistence.persist_event(event, context))
+
+        # Now get rooms
+        url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(1, channel.json_body["total"])
+        self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
+
 
 class PushersRestTestCase(unittest.HomeserverTestCase):
 
@@ -1401,7 +1665,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -1868,8 +2131,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 18932d7518..f9b8011961 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,19 +1,68 @@
-import json
+# -*- coding: utf-8 -*-
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 import urllib.parse
+from html.parser import HTMLParser
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 from mock import Mock
 
-import jwt
+import pymacaroons
+
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.types import create_requester
 
 from tests import unittest
-from tests.unittest import override_config
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.handlers.test_saml import has_saml2
+from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.unittest import override_config, skip_unless
+
+try:
+    import jwt
+
+    HAS_JWT = True
+except ImportError:
+    HAS_JWT = False
+
+
+# public_base_url used in some tests
+BASE_URL = "https://synapse/"
+
+# CAS server used in some tests
+CAS_SERVER = "https://fake.test"
+
+# just enough to tell pysaml2 where to redirect to
+SAML_SERVER = "https://test.saml.server/idp/sso"
+TEST_SAML_METADATA = """
+<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
+  <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
+      <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
+  </md:IDPSSODescriptor>
+</md:EntityDescriptor>
+""" % {
+    "SAML_SERVER": SAML_SERVER,
+}
 
 LOGIN_URL = b"/_matrix/client/r0/login"
 TEST_URL = b"/_matrix/client/r0/account/whoami"
@@ -311,6 +360,184 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
 
+@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
+class MultiSSOTestCase(unittest.HomeserverTestCase):
+    """Tests for homeservers with multiple SSO providers enabled"""
+
+    servlets = [
+        login.register_servlets,
+    ]
+
+    def default_config(self) -> Dict[str, Any]:
+        config = super().default_config()
+
+        config["public_baseurl"] = BASE_URL
+
+        config["cas_config"] = {
+            "enabled": True,
+            "server_url": CAS_SERVER,
+            "service_url": "https://matrix.goodserver.com:8448",
+        }
+
+        config["saml2_config"] = {
+            "sp_config": {
+                "metadata": {"inline": [TEST_SAML_METADATA]},
+                # use the XMLSecurity backend to avoid relying on xmlsec1
+                "crypto_backend": "XMLSecurity",
+            },
+        }
+
+        config["oidc_config"] = TEST_OIDC_CONFIG
+
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
+        return d
+
+    def test_multi_sso_redirect(self):
+        """/login/sso/redirect should redirect to an identity picker"""
+        client_redirect_url = "https://x?<abc>"
+
+        # first hit the redirect url, which should redirect to our idp picker
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        uri = channel.headers.getRawHeaders("Location")[0]
+
+        # hitting that picker should give us some HTML
+        channel = self.make_request("GET", uri)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # parse the form to check it has fields assumed elsewhere in this class
+        class FormPageParser(HTMLParser):
+            def __init__(self):
+                super().__init__()
+
+                # the values of the hidden inputs: map from name to value
+                self.hiddens = {}  # type: Dict[str, Optional[str]]
+
+                # the values of the radio buttons
+                self.radios = []  # type: List[Optional[str]]
+
+            def handle_starttag(
+                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+            ) -> None:
+                attr_dict = dict(attrs)
+                if tag == "input":
+                    if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
+                        self.radios.append(attr_dict["value"])
+                    elif attr_dict["type"] == "hidden":
+                        input_name = attr_dict["name"]
+                        assert input_name
+                        self.hiddens[input_name] = attr_dict["value"]
+
+            def error(_, message):
+                self.fail(message)
+
+        p = FormPageParser()
+        p.feed(channel.result["body"].decode("utf-8"))
+        p.close()
+
+        self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
+
+        self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
+
+    def test_multi_sso_redirect_to_cas(self):
+        """If CAS is chosen, should redirect to the CAS server"""
+        client_redirect_url = "https://x?<abc>"
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        cas_uri = channel.headers.getRawHeaders("Location")[0]
+        cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
+
+        # it should redirect us to the login page of the cas server
+        self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
+
+        # check that the redirectUrl is correctly encoded in the service param - ie, the
+        # place that CAS will redirect to
+        cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
+        service_uri = cas_uri_params["service"][0]
+        _, service_uri_query = service_uri.split("?", 1)
+        service_uri_params = urllib.parse.parse_qs(service_uri_query)
+        self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
+
+    def test_multi_sso_redirect_to_saml(self):
+        """If SAML is chosen, should redirect to the SAML server"""
+        client_redirect_url = "https://x?<abc>"
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + client_redirect_url
+            + "&idp=saml",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        saml_uri = channel.headers.getRawHeaders("Location")[0]
+        saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
+
+        # it should redirect us to the login page of the SAML server
+        self.assertEqual(saml_uri_path, SAML_SERVER)
+
+        # the RelayState is used to carry the client redirect url
+        saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
+        relay_state_param = saml_uri_params["RelayState"][0]
+        self.assertEqual(relay_state_param, client_redirect_url)
+
+    def test_multi_sso_redirect_to_oidc(self):
+        """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
+        client_redirect_url = "https://x?<abc>"
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + client_redirect_url
+            + "&idp=oidc",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+        # ... and should have set a cookie including the redirect url
+        cookies = dict(
+            h.split(";")[0].split("=", maxsplit=1)
+            for h in channel.headers.getRawHeaders("Set-Cookie")
+        )
+
+        oidc_session_cookie = cookies["oidc_session"]
+        macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
+        self.assertEqual(
+            self._get_value_from_macaroon(macaroon, "client_redirect_url"),
+            client_redirect_url,
+        )
+
+    def test_multi_sso_redirect_to_unknown(self):
+        """An unknown IdP should cause a 400"""
+        channel = self.make_request(
+            "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+    @staticmethod
+    def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+
 class CASTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -324,7 +551,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         config = self.default_config()
         config["cas_config"] = {
             "enabled": True,
-            "server_url": "https://fake.test",
+            "server_url": CAS_SERVER,
             "service_url": "https://matrix.goodserver.com:8448",
         }
 
@@ -385,7 +612,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", cas_ticket_url)
 
         # Test that the response is HTML.
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, 200, channel.result)
         content_type_header_value = ""
         for header in channel.result.get("headers", []):
             if header[0] == b"Content-Type":
@@ -410,8 +637,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_cas_redirect_whitelisted(self):
-        """Tests that the SSO login flow serves a redirect to a whitelisted url
-        """
+        """Tests that the SSO login flow serves a redirect to a whitelisted url"""
         self._test_redirect("https://legit-site.com/")
 
     @override_config({"public_baseurl": "https://example.com"})
@@ -442,7 +668,9 @@ class CASTestCase(unittest.HomeserverTestCase):
 
         # Deactivate the account.
         self.get_success(
-            self.deactivate_account_handler.deactivate_account(self.user_id, False)
+            self.deactivate_account_handler.deactivate_account(
+                self.user_id, False, create_requester(self.user_id)
+            )
         )
 
         # Request the CAS ticket.
@@ -459,6 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.assertIn(b"SSO account deactivated", channel.result["body"])
 
 
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -475,17 +704,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs
 
-    def jwt_encode(self, token: str, secret: str = jwt_secret) -> str:
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(token, secret, self.jwt_algorithm)
+        result = jwt.encode(
+            payload, secret, self.jwt_algorithm
+        )  # type: Union[str, bytes]
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
@@ -617,7 +846,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         )
 
     def test_login_no_token(self):
-        params = json.dumps({"type": "org.matrix.login.jwt"})
+        params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -627,6 +856,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
 # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
 # signed by the private key.
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTPubKeyTestCase(unittest.HomeserverTestCase):
     servlets = [
         login.register_servlets,
@@ -684,17 +914,15 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = "RS256"
         return self.hs
 
-    def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str:
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
         # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
-        result = jwt.encode(token, secret, "RS256")
+        result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes,str]
         if isinstance(result, bytes):
             return result.decode("ascii")
         return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
@@ -764,8 +992,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         return self.hs
 
     def test_login_appservice_user(self):
-        """Test that an appservice user can use /login
-        """
+        """Test that an appservice user can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -779,8 +1006,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_user_bot(self):
-        """Test that the appservice bot can use /login
-        """
+        """Test that the appservice bot can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -794,8 +1020,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_wrong_user(self):
-        """Test that non-as users cannot login with the as token
-        """
+        """Test that non-as users cannot login with the as token"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -809,8 +1034,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
     def test_login_appservice_wrong_as(self):
-        """Test that as users cannot login with wrong as token
-        """
+        """Test that as users cannot login with wrong as token"""
         self.register_as_user(AS_USER)
 
         params = {
@@ -825,7 +1049,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
 
     def test_login_appservice_no_token(self):
         """Test that users must provide a token when using the appservice
-           login method
+        login method
         """
         self.register_as_user(AS_USER)
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 6105eac47c..d4e3165436 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -29,7 +29,7 @@ from synapse.handlers.pagination import PurgeStatus
 from synapse.rest import admin
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias, UserID
+from synapse.types import JsonDict, RoomAlias, UserID, create_requester
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -1687,7 +1687,9 @@ class ContextTestCase(unittest.HomeserverTestCase):
 
         deactivate_account_handler = self.hs.get_deactivate_account_handler()
         self.get_success(
-            deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
+            deactivate_account_handler.deactivate_account(
+                self.user_id, True, create_requester(self.user_id)
+            )
         )
 
         # Invite another user in the room. This is needed because messages will be
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index dbc27893b5..81b7f84360 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -444,6 +444,7 @@ class RestHelper:
 
 
 # an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
 TEST_OIDC_CONFIG = {
     "enabled": True,
     "discover": False,
@@ -451,7 +452,7 @@ TEST_OIDC_CONFIG = {
     "client_id": "test-client-id",
     "client_secret": "test-client-secret",
     "scopes": ["profile"],
-    "authorization_endpoint": "https://z",
+    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
     "token_endpoint": "https://issuer.test/token",
     "userinfo_endpoint": "https://issuer.test/userinfo",
     "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index ac66a4e0b7..bb91e0c331 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -26,8 +26,10 @@ from synapse.rest.oidc import OIDCResource
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
+from tests.handlers.test_oidc import HAS_OIDC
 from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
+from tests.unittest import override_config, skip_unless
 
 
 class DummyRecaptchaChecker(UserInteractiveAuthChecker):
@@ -158,20 +160,22 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     def default_config(self):
         config = super().default_config()
+        config["public_baseurl"] = "https://synapse.test"
 
-        # we enable OIDC as a way of testing SSO flows
-        oidc_config = {}
-        oidc_config.update(TEST_OIDC_CONFIG)
-        oidc_config["allow_existing_users"] = True
+        if HAS_OIDC:
+            # we enable OIDC as a way of testing SSO flows
+            oidc_config = {}
+            oidc_config.update(TEST_OIDC_CONFIG)
+            oidc_config["allow_existing_users"] = True
+            config["oidc_config"] = oidc_config
 
-        config["oidc_config"] = oidc_config
-        config["public_baseurl"] = "https://synapse.test"
         return config
 
     def create_resource_dict(self):
         resource_dict = super().create_resource_dict()
-        # mount the OIDC resource at /_synapse/oidc
-        resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+        if HAS_OIDC:
+            # mount the OIDC resource at /_synapse/oidc
+            resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
         return resource_dict
 
     def prepare(self, reactor, clock, hs):
@@ -380,6 +384,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Note that *no auth* information is provided, not even a session iD!
         self.delete_device(self.user_tok, self.device_id, 200)
 
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_does_not_offer_password_for_sso_user(self):
         login_resp = self.helper.login_via_oidc("username")
         user_tok = login_resp["access_token"]
@@ -393,13 +399,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
 
     def test_does_not_offer_sso_for_password_user(self):
-        # now call the device deletion API: we should get the option to auth with SSO
-        # and not password.
         channel = self.delete_device(self.user_tok, self.device_id, 401)
 
         flows = channel.json_body["flows"]
         self.assertEqual(flows, [{"stages": ["m.login.password"]}])
 
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_offers_both_flows_for_upgraded_user(self):
         """A user that had a password and then logged in with SSO should get both flows
         """
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 83d728b4a4..6968502433 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -26,8 +26,15 @@ from twisted.test.proto_helpers import AccumulatingProtocol
 from tests import unittest
 from tests.server import FakeTransport
 
+try:
+    import lxml
+except ImportError:
+    lxml = None
+
 
 class URLPreviewTests(unittest.HomeserverTestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
 
     hijack_auth = True
     user_id = "@test:user"
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
new file mode 100644
index 0000000000..673e1fe3e3
--- /dev/null
+++ b/tests/storage/test_account_data.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Set
+
+from synapse.api.constants import AccountDataTypes
+
+from tests import unittest
+
+
+class IgnoredUsersTestCase(unittest.HomeserverTestCase):
+    def prepare(self, hs, reactor, clock):
+        self.store = self.hs.get_datastore()
+        self.user = "@user:test"
+
+    def _update_ignore_list(
+        self, *ignored_user_ids: Iterable[str], ignorer_user_id: str = None
+    ) -> None:
+        """Update the account data to block the given users."""
+        if ignorer_user_id is None:
+            ignorer_user_id = self.user
+
+        self.get_success(
+            self.store.add_account_data_for_user(
+                ignorer_user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {u: {} for u in ignored_user_ids}},
+            )
+        )
+
+    def assert_ignorers(
+        self, ignored_user_id: str, expected_ignorer_user_ids: Set[str]
+    ) -> None:
+        self.assertEqual(
+            self.get_success(self.store.ignored_by(ignored_user_id)),
+            expected_ignorer_user_ids,
+        )
+
+    def test_ignoring_users(self):
+        """Basic adding/removing of users from the ignore list."""
+        self._update_ignore_list("@other:test", "@another:remote")
+
+        # Check a user which no one ignores.
+        self.assert_ignorers("@user:test", set())
+
+        # Check a local user which is ignored.
+        self.assert_ignorers("@other:test", {self.user})
+
+        # Check a remote user which is ignored.
+        self.assert_ignorers("@another:remote", {self.user})
+
+        # Add one user, remove one user, and leave one user.
+        self._update_ignore_list("@foo:test", "@another:remote")
+
+        # Check the removed user.
+        self.assert_ignorers("@other:test", set())
+
+        # Check the added user.
+        self.assert_ignorers("@foo:test", {self.user})
+
+        # Check the removed user.
+        self.assert_ignorers("@another:remote", {self.user})
+
+    def test_caching(self):
+        """Ensure that caching works properly between different users."""
+        # The first user ignores a user.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # The second user ignores them.
+        self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+        self.assert_ignorers("@other:test", {self.user, "@second:test"})
+
+        # The first user un-ignores them.
+        self._update_ignore_list()
+        self.assert_ignorers("@other:test", {"@second:test"})
+
+    def test_invalid_data(self):
+        """Invalid data ends up clearing out the ignored users list."""
+        # Add some data and ensure it is there.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # No ignored_users key.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user, AccountDataTypes.IGNORED_USER_LIST, {},
+            )
+        )
+
+        # No one ignores the user now.
+        self.assert_ignorers("@other:test", set())
+
+        # Add some data and ensure it is there.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # Invalid data.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": "unexpected"},
+            )
+        )
+
+        # No one ignores the user now.
+        self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
new file mode 100644
index 0000000000..83c377824b
--- /dev/null
+++ b/tests/storage/test_event_chain.py
@@ -0,0 +1,472 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Tuple
+
+from twisted.trial import unittest
+
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.storage.databases.main.events import _LinkMap
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventChainStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self._next_stream_ordering = 1
+
+    def test_simple(self):
+        """Test that the example in `docs/auth_chain_difference_algorithm.md`
+        works.
+        """
+
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        power_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        bob_join_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[
+                    create.event_id,
+                    alice_join.event_id,
+                    power_2.event_id,
+                ],
+            )
+        )
+
+        events = [
+            create,
+            bob_join,
+            power,
+            alice_invite,
+            alice_join,
+            bob_join_2,
+            power_2,
+            alice_join2,
+        ]
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+            (bob_join_2, power),
+            (alice_join2, power_2),
+        ]
+
+        self.persist(events)
+        chain_map, link_map = self.fetch_chains(events)
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+        # Test that everything can reach the create event, but the create event
+        # can't reach anything.
+        for event in events[1:]:
+            self.assertTrue(
+                link_map.exists_path_from(
+                    chain_map[event.event_id], chain_map[create.event_id]
+                ),
+            )
+
+            self.assertFalse(
+                link_map.exists_path_from(
+                    chain_map[create.event_id], chain_map[event.event_id],
+                ),
+            )
+
+    def test_out_of_order_events(self):
+        """Test that we handle persisting events that we don't have the full
+        auth chain for yet (which should only happen for out of band memberships).
+        """
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        # First persist the base room.
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        self.persist([create, bob_join, power])
+
+        # Now persist an invite and a couple of memberships out of order.
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
+            )
+        )
+
+        self.persist([alice_join])
+        self.persist([alice_join2])
+        self.persist([alice_invite])
+
+        # The end result should be sane.
+        events = [create, bob_join, power, alice_invite, alice_join]
+
+        chain_map, link_map = self.fetch_chains(events)
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+        ]
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+    def persist(
+        self, events: List[EventBase],
+    ):
+        """Persist the given events and check that the links generated match
+        those given.
+        """
+
+        persist_events_store = self.hs.get_datastores().persist_events
+
+        for e in events:
+            e.internal_metadata.stream_ordering = self._next_stream_ordering
+            self._next_stream_ordering += 1
+
+        def _persist(txn):
+            # We need to persist the events to the events and state_events
+            # tables.
+            persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+
+            # Actually call the function that calculates the auth chain stuff.
+            persist_events_store._persist_event_auth_chain_txn(txn, events)
+
+        self.get_success(
+            persist_events_store.db_pool.runInteraction("_persist", _persist,)
+        )
+
+    def fetch_chains(
+        self, events: List[EventBase]
+    ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
+
+        # Fetch the map from event ID -> (chain ID, sequence number)
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chains",
+                column="event_id",
+                iterable=[e.event_id for e in events],
+                retcols=("event_id", "chain_id", "sequence_number"),
+                keyvalues={},
+            )
+        )
+
+        chain_map = {
+            row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+        }
+
+        # Fetch all the links and pass them to the _LinkMap.
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chain_links",
+                column="origin_chain_id",
+                iterable=[chain_id for chain_id, _ in chain_map.values()],
+                retcols=(
+                    "origin_chain_id",
+                    "origin_sequence_number",
+                    "target_chain_id",
+                    "target_sequence_number",
+                ),
+                keyvalues={},
+            )
+        )
+
+        link_map = _LinkMap()
+        for row in rows:
+            added = link_map.add_link(
+                (row["origin_chain_id"], row["origin_sequence_number"]),
+                (row["target_chain_id"], row["target_sequence_number"]),
+            )
+
+            # We shouldn't have persisted any redundant links
+            self.assertTrue(added)
+
+        return chain_map, link_map
+
+
+class LinkMapTestCase(unittest.TestCase):
+    def test_simple(self):
+        """Basic tests for the LinkMap.
+        """
+        link_map = _LinkMap()
+
+        link_map.add_link((1, 1), (2, 1), new=False)
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+        self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
+        self.assertCountEqual(link_map.get_additions(), [])
+        self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
+        self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
+
+        # Attempting to add a redundant link is ignored.
+        self.assertFalse(link_map.add_link((1, 4), (2, 1)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+
+        # Adding new non-redundant links works
+        self.assertTrue(link_map.add_link((1, 3), (2, 3)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertTrue(link_map.add_link((2, 5), (1, 3)))
+        self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 482506d731..9d04a066d8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
 import tests.unittest
 import tests.utils
 
@@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])
 
-    def test_auth_difference(self):
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_difference(self, use_chain_cover_index: bool):
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "j": 1,
         }
 
+        # Mark the room as not having a cover index
+
+        def store_room(txn):
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                "rooms",
+                {
+                    "room_id": room_id,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": use_chain_cover_index,
+                },
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn, event_id, stream_ordering):
+        def insert_event(txn):
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn,
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                ],
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
+        # Now actually test that various combinations give the right result:
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
+        self.assertSetEqual(difference, set())
+
+    def test_auth_difference_partial_cover(self):
+        """Test that we correctly handle rooms where not all events have a chain
+        cover calculated. This can happen in some obscure edge cases, including
+        during the background update that calculates the chain cover for old
+        rooms.
+        """
+
+        room_id = "@ROOM:local"
+
+        # The silly auth graph we use to test the auth difference algorithm,
+        # where the top are the most recent events.
+        #
+        #   A   B
+        #    \ /
+        #  D  E
+        #  \  |
+        #   ` F   C
+        #     |  /|
+        #     G ´ |
+        #     | \ |
+        #     H   I
+        #     |   |
+        #     K   J
+
+        auth_graph = {
+            "a": ["e"],
+            "b": ["e"],
+            "c": ["g", "i"],
+            "d": ["f"],
+            "e": ["f"],
+            "f": ["g"],
+            "g": ["h", "i"],
+            "h": ["k"],
+            "i": ["j"],
+            "k": [],
+            "j": [],
+        }
+
+        depth_map = {
+            "a": 7,
+            "b": 7,
+            "c": 4,
+            "d": 6,
+            "e": 6,
+            "f": 5,
+            "g": 3,
+            "h": 2,
+            "i": 2,
+            "k": 1,
+            "j": 1,
+        }
 
-            depth = depth_map[event_id]
+        # We rudely fiddle with the appropriate tables directly, as that's much
+        # easier than constructing events properly.
 
+        def insert_event(txn):
+            # First insert the room and mark it as having a chain cover.
             self.store.db_pool.simple_insert_txn(
                 txn,
-                table="events",
-                values={
-                    "event_id": event_id,
+                "rooms",
+                {
                     "room_id": room_id,
-                    "depth": depth,
-                    "topological_ordering": depth,
-                    "type": "m.test",
-                    "processed": True,
-                    "outlier": False,
-                    "stream_ordering": stream_ordering,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": True,
                 },
             )
 
-            self.store.db_pool.simple_insert_many_txn(
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            # Insert all events apart from 'B'
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
-                table="event_auth",
-                values=[
-                    {"event_id": event_id, "room_id": room_id, "auth_id": a}
-                    for a in auth_graph[event_id]
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                    if event_id != "b"
                 ],
             )
 
-        next_stream_ordering = 0
-        for event_id in auth_graph:
-            next_stream_ordering += 1
-            self.get_success(
-                self.store.db_pool.runInteraction(
-                    "insert", insert_event, event_id, next_stream_ordering
-                )
+            # Now we insert the event 'B' without a chain cover, by temporarily
+            # pretending the room doesn't have a chain cover.
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+            )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn, [FakeEvent("b", room_id, auth_graph["b"])],
+            )
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": True},
             )
 
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(
@@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.store.get_auth_chain_difference(room_id, [{"a"}])
         )
         self.assertSetEqual(difference, set())
+
+
+@attr.s
+class FakeEvent:
+    event_id = attr.ib()
+    room_id = attr.ib()
+    auth_events = attr.ib()
+
+    type = "foo"
+    state_key = "foo"
+
+    internal_metadata = _EventInternalMetadata({})
+
+    def auth_event_ids(self):
+        return self.auth_events
+
+    def is_state(self):
+        return True
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 3fd0a38cf5..ea63bd56b4 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase):
             ),
         )
 
+        # test set to None
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.u_frank.localpart, None)
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.u_frank.localpart)
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     def test_avatar_url(self):
         yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
@@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase):
                 )
             ),
         )
+
+        # test set to None
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(self.u_frank.localpart, None)
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.u_frank.localpart)
+                )
+            )
+        )
diff --git a/tests/test_preview.py b/tests/test_preview.py
index a883d707df..c19facc1cb 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -20,8 +20,16 @@ from synapse.rest.media.v1.preview_url_resource import (
 
 from . import unittest
 
+try:
+    import lxml
+except ImportError:
+    lxml = None
+
 
 class PreviewTestCase(unittest.TestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
+
     def test_long_summarize(self):
         example_paras = [
             """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
@@ -137,6 +145,9 @@ class PreviewTestCase(unittest.TestCase):
 
 
 class PreviewUrlTestCase(unittest.TestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
+
     def test_simple(self):
         html = """
         <html>
diff --git a/tests/unittest.py b/tests/unittest.py
index af7f752c5a..bbd295687c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
 import inspect
 import logging
 import time
-from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
 
 from mock import Mock, patch
 
@@ -736,3 +736,29 @@ def override_config(extra_config):
         return func
 
     return decorator
+
+
+TV = TypeVar("TV")
+
+
+def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
+    """A test decorator which will skip the decorated test unless a condition is set
+
+    For example:
+
+    class MyTestCase(TestCase):
+        @skip_unless(HAS_FOO, "Cannot test without foo")
+        def test_foo(self):
+            ...
+
+    Args:
+        condition: If true, the test will be skipped
+        reason: the reason to give for skipping the test
+    """
+
+    def decorator(f: TV) -> TV:
+        if not condition:
+            f.skip = reason  # type: ignore
+        return f
+
+    return decorator
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index dadfabd46d..ecd9efc4df 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -25,13 +25,8 @@ from tests.unittest import TestCase
 class DeferredCacheTestCase(TestCase):
     def test_empty(self):
         cache = DeferredCache("test")
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get("foo")
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
     def test_hit(self):
         cache = DeferredCache("test")
@@ -155,13 +150,8 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill(("foo",), 123)
         cache.invalidate(("foo",))
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(("foo",))
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
     def test_invalidate_all(self):
         cache = DeferredCache("testcache")
@@ -215,13 +205,8 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill(2, "two")
         cache.prefill(3, "three")  # 1 will be evicted
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(1)
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
         cache.get(2)
         cache.get(3)
@@ -239,13 +224,55 @@ class DeferredCacheTestCase(TestCase):
 
         cache.prefill(3, "three")
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(2)
-        except KeyError:
-            failed = True
 
-        self.assertTrue(failed)
+        cache.get(1)
+        cache.get(3)
+
+    def test_eviction_iterable(self):
+        cache = DeferredCache(
+            "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True,
+        )
+
+        cache.prefill(1, ["one", "two"])
+        cache.prefill(2, ["three"])
 
+        # Now access 1 again, thus causing 2 to be least-recently used
+        cache.get(1)
+
+        # Now add an item to the cache, which evicts 2.
+        cache.prefill(3, ["four"])
+        with self.assertRaises(KeyError):
+            cache.get(2)
+
+        # Ensure 1 & 3 are in the cache.
         cache.get(1)
         cache.get(3)
+
+        # Now access 1 again, thus causing 3 to be least-recently used
+        cache.get(1)
+
+        # Now add an item with multiple elements to the cache
+        cache.prefill(4, ["five", "six"])
+
+        # Both 1 and 3 are evicted since there's too many elements.
+        with self.assertRaises(KeyError):
+            cache.get(1)
+        with self.assertRaises(KeyError):
+            cache.get(3)
+
+        # Now add another item to fill the cache again.
+        cache.prefill(5, ["seven"])
+
+        # Now access 4, thus causing 5 to be least-recently used
+        cache.get(4)
+
+        # Add an empty item.
+        cache.prefill(6, [])
+
+        # 5 gets evicted and replaced since an empty element counts as an item.
+        with self.assertRaises(KeyError):
+            cache.get(5)
+        cache.get(4)
+        cache.get(6)
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 0ab0a91483..1184cea5a3 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -12,7 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically
 
 from tests.unittest import TestCase
 
@@ -45,3 +47,40 @@ class ChunkSeqTests(TestCase):
         self.assertEqual(
             list(parts), [],
         )
+
+
+class SortTopologically(TestCase):
+    def test_empty(self):
+        "Test that an empty graph works correctly"
+
+        graph = {}  # type: Dict[int, List[int]]
+        self.assertEqual(list(sorted_topologically([], graph)), [])
+
+    def test_disconnected(self):
+        "Test that a graph with no edges work"
+
+        graph = {1: [], 2: []}  # type: Dict[int, List[int]]
+
+        # For disconnected nodes the output is simply sorted.
+        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+    def test_linear(self):
+        "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+    def test_subset(self):
+        "Test that only sorting a subset of the graph works"
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+    def test_fork(self):
+        "Test that a forked graph works"
+        graph = {1: [], 2: [1], 3: [1], 4: [2, 3]}  # type: Dict[int, List[int]]
+
+        # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+        # always get the same one.
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
diff --git a/tox.ini b/tox.ini
index 8e8b495292..297136fcc5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,7 +2,6 @@
 envlist = packaging, py35, py36, py37, py38, py39, check_codestyle, check_isort
 
 [base]
-extras = test
 deps =
     python-subunit
     junitxml
@@ -25,10 +24,29 @@ deps =
     # install the "enum34" dependency of cryptography.
     pip>=10
 
+# directories/files we run the linters on
+lint_targets =
+    setup.py
+    synapse
+    tests
+    scripts
+    scripts-dev
+    stubs
+    contrib
+    synctl
+    synmark
+    .buildkite
+    docker
+
+# default settings for all tox environments
 [testenv]
 deps =
     {[base]deps}
-extras = all, test
+extras =
+    # install the optional dependendencies for tox environments without
+    # '-noextras' in their name
+    !noextras: all
+    test
 
 setenv =
     # use a postgres db for tox environments with "-postgres" in the name
@@ -126,13 +144,13 @@ commands =
 [testenv:check_codestyle]
 extras = lint
 commands =
-    python -m black --check --diff .
-    /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}"
+    python -m black --check --diff {[base]lint_targets}
+    flake8 {[base]lint_targets} {env:PEP8SUFFIX:}
     {toxinidir}/scripts-dev/config-lint.sh
 
 [testenv:check_isort]
 extras = lint
-commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts"
+commands = isort -c --df --sp setup.cfg {[base]lint_targets}
 
 [testenv:check-newsfragment]
 skip_install = True