summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-x.ci/patch_for_twisted_trunk.sh8
-rw-r--r--.ci/twisted_trunk_build_failed_issue_template.md4
-rw-r--r--.github/workflows/twisted_trunk.yml90
-rw-r--r--changelog.d/10142.feature1
-rw-r--r--changelog.d/10192.doc1
-rw-r--r--changelog.d/10452.feature1
-rw-r--r--changelog.d/10524.feature1
-rw-r--r--changelog.d/10561.bugfix1
-rw-r--r--changelog.d/10593.bugfix1
-rw-r--r--changelog.d/10608.misc1
-rw-r--r--changelog.d/10613.feature1
-rw-r--r--changelog.d/10614.misc1
-rw-r--r--changelog.d/10615.misc1
-rw-r--r--changelog.d/10624.misc1
-rw-r--r--changelog.d/10627.misc1
-rw-r--r--changelog.d/10629.misc1
-rw-r--r--changelog.d/10630.misc1
-rw-r--r--changelog.d/10639.doc1
-rw-r--r--changelog.d/10640.misc1
-rw-r--r--changelog.d/10642.misc1
-rw-r--r--changelog.d/10644.bugfix1
-rw-r--r--changelog.d/10651.misc1
-rw-r--r--changelog.d/10654.bugfix1
-rw-r--r--changelog.d/10662.misc1
-rw-r--r--changelog.d/10664.misc1
-rw-r--r--changelog.d/10665.misc1
-rw-r--r--changelog.d/10666.misc1
-rw-r--r--changelog.d/10667.misc1
-rw-r--r--changelog.d/10672.misc1
-rw-r--r--changelog.d/10677.bugfix1
-rw-r--r--changelog.d/8830.removal1
-rw-r--r--docs/SUMMARY.md3
-rw-r--r--docs/admin_api/purge_room.md21
-rw-r--r--docs/admin_api/shutdown_room.md102
-rw-r--r--docs/admin_api/user_admin_api.md8
-rw-r--r--docs/modules.md46
-rw-r--r--docs/openid.md64
-rw-r--r--docs/presence_router_module.md6
-rw-r--r--docs/sample_config.yaml29
-rw-r--r--docs/upgrade.md22
-rw-r--r--docs/usage/administration/admin_api/registration_tokens.md295
-rw-r--r--docs/workers.md1
-rw-r--r--mypy.ini9
-rw-r--r--synapse/api/constants.py1
-rw-r--r--synapse/api/errors.py8
-rw-r--r--synapse/app/_base.py2
-rw-r--r--synapse/app/generic_worker.py8
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/config/ratelimiting.py11
-rw-r--r--synapse/config/registration.py15
-rw-r--r--synapse/config/server.py15
-rw-r--r--synapse/events/presence_router.py192
-rw-r--r--synapse/federation/federation_client.py7
-rw-r--r--synapse/federation/federation_server.py4
-rw-r--r--synapse/handlers/auth.py23
-rw-r--r--synapse/handlers/federation.py387
-rw-r--r--synapse/handlers/initial_sync.py2
-rw-r--r--synapse/handlers/register.py18
-rw-r--r--synapse/handlers/room_member.py17
-rw-r--r--synapse/handlers/room_summary.py76
-rw-r--r--synapse/handlers/sso.py2
-rw-r--r--synapse/handlers/sync.py198
-rw-r--r--synapse/handlers/ui_auth/__init__.py5
-rw-r--r--synapse/handlers/ui_auth/checkers.py75
-rw-r--r--synapse/http/additional_resource.py13
-rw-r--r--synapse/http/federation/srv_resolver.py45
-rw-r--r--synapse/http/proxyagent.py4
-rw-r--r--synapse/module_api/__init__.py10
-rw-r--r--synapse/res/templates/recaptcha.html3
-rw-r--r--synapse/res/templates/registration_token.html23
-rw-r--r--synapse/res/templates/terms.html3
-rw-r--r--synapse/rest/admin/__init__.py12
-rw-r--r--synapse/rest/admin/purge_room_servlet.py58
-rw-r--r--synapse/rest/admin/registration_tokens.py321
-rw-r--r--synapse/rest/admin/rooms.py35
-rw-r--r--synapse/rest/admin/users.py55
-rw-r--r--synapse/rest/client/account_validity.py39
-rw-r--r--synapse/rest/client/auth.py63
-rw-r--r--synapse/rest/client/capabilities.py14
-rw-r--r--synapse/rest/client/directory.py78
-rw-r--r--synapse/rest/client/keys.py16
-rw-r--r--synapse/rest/client/login.py29
-rw-r--r--synapse/rest/client/pusher.py22
-rw-r--r--synapse/rest/client/read_marker.py15
-rw-r--r--synapse/rest/client/register.py72
-rw-r--r--synapse/rest/client/room_upgrade_rest_servlet.py18
-rw-r--r--synapse/rest/client/shared_rooms.py16
-rw-r--r--synapse/rest/client/sync.py132
-rw-r--r--synapse/rest/client/tags.py25
-rw-r--r--synapse/rest/client/thirdparty.py36
-rw-r--r--synapse/rest/client/tokenrefresh.py14
-rw-r--r--synapse/rest/client/user_directory.py17
-rw-r--r--synapse/rest/client/versions.py14
-rw-r--r--synapse/rest/client/voip.py13
-rw-r--r--synapse/static/client/register/style.css6
-rw-r--r--synapse/storage/databases/main/__init__.py2
-rw-r--r--synapse/storage/databases/main/registration.py341
-rw-r--r--synapse/storage/databases/main/roommember.py16
-rw-r--r--synapse/storage/databases/main/session.py145
-rw-r--r--synapse/storage/databases/main/ui_auth.py43
-rw-r--r--synapse/storage/databases/main/user_directory.py2
-rw-r--r--synapse/storage/roommember.py44
-rw-r--r--synapse/storage/schema/main/delta/62/02session_store.sql23
-rw-r--r--synapse/storage/schema/main/delta/63/01create_registration_tokens.sql23
-rw-r--r--tests/events/test_presence_router.py109
-rw-r--r--tests/handlers/test_sync.py137
-rw-r--r--tests/replication/slave/storage/test_events.py10
-rw-r--r--tests/rest/admin/test_device.py99
-rw-r--r--tests/rest/admin/test_registration_tokens.py710
-rw-r--r--tests/rest/admin/test_room.py162
-rw-r--r--tests/rest/admin/test_user.py62
-rw-r--r--tests/rest/client/test_account.py (renamed from tests/rest/client/v2_alpha/test_account.py)0
-rw-r--r--tests/rest/client/test_auth.py (renamed from tests/rest/client/v2_alpha/test_auth.py)2
-rw-r--r--tests/rest/client/test_capabilities.py (renamed from tests/rest/client/v2_alpha/test_capabilities.py)109
-rw-r--r--tests/rest/client/test_directory.py (renamed from tests/rest/client/v1/test_directory.py)0
-rw-r--r--tests/rest/client/test_events.py (renamed from tests/rest/client/v1/test_events.py)0
-rw-r--r--tests/rest/client/test_filter.py (renamed from tests/rest/client/v2_alpha/test_filter.py)0
-rw-r--r--tests/rest/client/test_keys.py91
-rw-r--r--tests/rest/client/test_login.py (renamed from tests/rest/client/v1/test_login.py)2
-rw-r--r--tests/rest/client/test_password_policy.py (renamed from tests/rest/client/v2_alpha/test_password_policy.py)0
-rw-r--r--tests/rest/client/test_presence.py (renamed from tests/rest/client/v1/test_presence.py)0
-rw-r--r--tests/rest/client/test_profile.py (renamed from tests/rest/client/v1/test_profile.py)0
-rw-r--r--tests/rest/client/test_push_rule_attrs.py (renamed from tests/rest/client/v1/test_push_rule_attrs.py)0
-rw-r--r--tests/rest/client/test_register.py (renamed from tests/rest/client/v2_alpha/test_register.py)434
-rw-r--r--tests/rest/client/test_relations.py (renamed from tests/rest/client/v2_alpha/test_relations.py)0
-rw-r--r--tests/rest/client/test_report_event.py (renamed from tests/rest/client/v2_alpha/test_report_event.py)0
-rw-r--r--tests/rest/client/test_rooms.py (renamed from tests/rest/client/v1/test_rooms.py)0
-rw-r--r--tests/rest/client/test_sendtodevice.py (renamed from tests/rest/client/v2_alpha/test_sendtodevice.py)0
-rw-r--r--tests/rest/client/test_shared_rooms.py (renamed from tests/rest/client/v2_alpha/test_shared_rooms.py)0
-rw-r--r--tests/rest/client/test_sync.py (renamed from tests/rest/client/v2_alpha/test_sync.py)0
-rw-r--r--tests/rest/client/test_typing.py (renamed from tests/rest/client/v1/test_typing.py)0
-rw-r--r--tests/rest/client/test_upgrade_room.py (renamed from tests/rest/client/v2_alpha/test_upgrade_room.py)0
-rw-r--r--tests/rest/client/utils.py (renamed from tests/rest/client/v1/utils.py)0
-rw-r--r--tests/rest/client/v1/__init__.py13
-rw-r--r--tests/rest/client/v2_alpha/__init__.py0
-rw-r--r--tests/test_federation.py10
-rw-r--r--tests/unittest.py2
137 files changed, 4442 insertions, 1175 deletions
diff --git a/.ci/patch_for_twisted_trunk.sh b/.ci/patch_for_twisted_trunk.sh
new file mode 100755
index 0000000000..f524581986
--- /dev/null
+++ b/.ci/patch_for_twisted_trunk.sh
@@ -0,0 +1,8 @@
+#!/bin/sh
+
+# replaces the dependency on Twisted in `python_dependencies` with trunk.
+
+set -e
+cd "$(dirname "$0")"/..
+
+sed -i -e 's#"Twisted.*"#"Twisted @ git+https://github.com/twisted/twisted"#' synapse/python_dependencies.py
diff --git a/.ci/twisted_trunk_build_failed_issue_template.md b/.ci/twisted_trunk_build_failed_issue_template.md
new file mode 100644
index 0000000000..2ead1dc394
--- /dev/null
+++ b/.ci/twisted_trunk_build_failed_issue_template.md
@@ -0,0 +1,4 @@
+---
+title: CI run against Twisted trunk is failing
+---
+See https://github.com/{{env.GITHUB_REPOSITORY}}/actions/runs/{{env.GITHUB_RUN_ID}}
diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml
new file mode 100644
index 0000000000..b5c729888f
--- /dev/null
+++ b/.github/workflows/twisted_trunk.yml
@@ -0,0 +1,90 @@
+name: Twisted Trunk
+
+on:
+  schedule:
+    - cron: 0 8 * * *
+
+  workflow_dispatch:
+
+jobs:
+  mypy:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+      - run: .ci/patch_for_twisted_trunk.sh
+      - run: pip install tox
+      - run: tox -e mypy
+
+  trial:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - run: sudo apt-get -qq install xmlsec1
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.6
+      - run: .ci/patch_for_twisted_trunk.sh
+      - run: pip install tox
+      - run: tox -e py
+        env:
+          TRIAL_FLAGS: "--jobs=2"
+
+      - name: Dump logs
+        # Note: Dumps to workflow logs instead of using actions/upload-artifact
+        #       This keeps logs colocated with failing jobs
+        #       It also ignores find's exit code; this is a best effort affair
+        run: >-
+          find _trial_temp -name '*.log'
+          -exec echo "::group::{}" \;
+          -exec cat {} \;
+          -exec echo "::endgroup::" \;
+          || true
+
+  sytest:
+    runs-on: ubuntu-latest
+    container:
+      image: matrixdotorg/sytest-synapse:buster
+      volumes:
+        - ${{ github.workspace }}:/src
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Patch dependencies
+        run: .ci/patch_for_twisted_trunk.sh
+        working-directory: /src
+      - name: Run SyTest
+        run: /bootstrap.sh synapse
+        working-directory: /src
+      - name: Summarise results.tap
+        if: ${{ always() }}
+        run: /sytest/scripts/tap_to_gha.pl /logs/results.tap
+      - name: Upload SyTest logs
+        uses: actions/upload-artifact@v2
+        if: ${{ always() }}
+        with:
+          name: Sytest Logs - ${{ job.status }} - (${{ join(matrix.*, ', ') }})
+          path: |
+            /logs/results.tap
+            /logs/**/*.log*
+
+  # open an issue if the build fails, so we know about it.
+  open-issue:
+    if: failure()
+    needs:
+      - mypy
+      - trial
+      - sytest
+
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - uses: JasonEtco/create-an-issue@5d9504915f79f9cc6d791934b8ef34f2353dd74d # v2.5.0, 2020-12-06
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          update_existing: true
+          filename: .ci/twisted_trunk_build_failed_issue_template.md
diff --git a/changelog.d/10142.feature b/changelog.d/10142.feature
new file mode 100644
index 0000000000..5353f6269d
--- /dev/null
+++ b/changelog.d/10142.feature
@@ -0,0 +1 @@
+Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown.
diff --git a/changelog.d/10192.doc b/changelog.d/10192.doc
new file mode 100644
index 0000000000..3dd00537e8
--- /dev/null
+++ b/changelog.d/10192.doc
@@ -0,0 +1 @@
+Add documentation on how to connect Django with synapse using oidc and django-oauth-toolkit. Contributed by @HugoDelval.
diff --git a/changelog.d/10452.feature b/changelog.d/10452.feature
new file mode 100644
index 0000000000..f332b383e3
--- /dev/null
+++ b/changelog.d/10452.feature
@@ -0,0 +1 @@
+Add support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283): Expose enable_set_displayname in capabilities.
\ No newline at end of file
diff --git a/changelog.d/10524.feature b/changelog.d/10524.feature
new file mode 100644
index 0000000000..288c9bd74e
--- /dev/null
+++ b/changelog.d/10524.feature
@@ -0,0 +1 @@
+Port the PresenceRouter module interface to the new generic interface.
\ No newline at end of file
diff --git a/changelog.d/10561.bugfix b/changelog.d/10561.bugfix
new file mode 100644
index 0000000000..2e4f53508c
--- /dev/null
+++ b/changelog.d/10561.bugfix
@@ -0,0 +1 @@
+Display an error on User-Interactive Authentication fallback pages when authentication fails. Contributed by Callum Brown.
diff --git a/changelog.d/10593.bugfix b/changelog.d/10593.bugfix
new file mode 100644
index 0000000000..af910bfa4d
--- /dev/null
+++ b/changelog.d/10593.bugfix
@@ -0,0 +1 @@
+Reject Client-Server `/keys/query` requests which provide `device_ids` incorrectly.
diff --git a/changelog.d/10608.misc b/changelog.d/10608.misc
new file mode 100644
index 0000000000..875bdd2fd0
--- /dev/null
+++ b/changelog.d/10608.misc
@@ -0,0 +1 @@
+Improve type hints for the proxy agent and SRV resolver modules. Contributed by @dklimpel.
\ No newline at end of file
diff --git a/changelog.d/10613.feature b/changelog.d/10613.feature
new file mode 100644
index 0000000000..ffc4e4289c
--- /dev/null
+++ b/changelog.d/10613.feature
@@ -0,0 +1 @@
+Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).
diff --git a/changelog.d/10614.misc b/changelog.d/10614.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10614.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/changelog.d/10615.misc b/changelog.d/10615.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10615.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/changelog.d/10624.misc b/changelog.d/10624.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10624.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/changelog.d/10627.misc b/changelog.d/10627.misc
new file mode 100644
index 0000000000..e6d314976e
--- /dev/null
+++ b/changelog.d/10627.misc
@@ -0,0 +1 @@
+Remove not needed database updates in modify user admin API.
\ No newline at end of file
diff --git a/changelog.d/10629.misc b/changelog.d/10629.misc
new file mode 100644
index 0000000000..cca1eb6c57
--- /dev/null
+++ b/changelog.d/10629.misc
@@ -0,0 +1 @@
+Convert room member storage tuples to `attrs` classes.
diff --git a/changelog.d/10630.misc b/changelog.d/10630.misc
new file mode 100644
index 0000000000..7d01e00e48
--- /dev/null
+++ b/changelog.d/10630.misc
@@ -0,0 +1 @@
+Use auto-attribs for the attrs classes used in sync.
diff --git a/changelog.d/10639.doc b/changelog.d/10639.doc
new file mode 100644
index 0000000000..acbac4aad8
--- /dev/null
+++ b/changelog.d/10639.doc
@@ -0,0 +1 @@
+Fix some of the titles not rendering in the OIDC documentation.
diff --git a/changelog.d/10640.misc b/changelog.d/10640.misc
new file mode 100644
index 0000000000..9a765435db
--- /dev/null
+++ b/changelog.d/10640.misc
@@ -0,0 +1 @@
+Clean up some of the federation event authentication code for clarity.
diff --git a/changelog.d/10642.misc b/changelog.d/10642.misc
new file mode 100644
index 0000000000..cca1eb6c57
--- /dev/null
+++ b/changelog.d/10642.misc
@@ -0,0 +1 @@
+Convert room member storage tuples to `attrs` classes.
diff --git a/changelog.d/10644.bugfix b/changelog.d/10644.bugfix
new file mode 100644
index 0000000000..d88a81fd82
--- /dev/null
+++ b/changelog.d/10644.bugfix
@@ -0,0 +1 @@
+Rooms with unsupported room versions are no longer returned via `/sync`.
diff --git a/changelog.d/10651.misc b/changelog.d/10651.misc
new file mode 100644
index 0000000000..7104c121e0
--- /dev/null
+++ b/changelog.d/10651.misc
@@ -0,0 +1 @@
+Run a nightly CI build against Twisted trunk.
diff --git a/changelog.d/10654.bugfix b/changelog.d/10654.bugfix
new file mode 100644
index 0000000000..b0bd78453f
--- /dev/null
+++ b/changelog.d/10654.bugfix
@@ -0,0 +1 @@
+Enforce the maximum length for per-room display names and avatar URLs.
\ No newline at end of file
diff --git a/changelog.d/10662.misc b/changelog.d/10662.misc
new file mode 100644
index 0000000000..593f9ceaad
--- /dev/null
+++ b/changelog.d/10662.misc
@@ -0,0 +1 @@
+Do not print out stack traces for network errors when fetching data over federation.
diff --git a/changelog.d/10664.misc b/changelog.d/10664.misc
new file mode 100644
index 0000000000..cebd5e9a96
--- /dev/null
+++ b/changelog.d/10664.misc
@@ -0,0 +1 @@
+Simplify tests for device admin rest API.
\ No newline at end of file
diff --git a/changelog.d/10665.misc b/changelog.d/10665.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10665.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10666.misc b/changelog.d/10666.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10666.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/changelog.d/10667.misc b/changelog.d/10667.misc
new file mode 100644
index 0000000000..c92846ae26
--- /dev/null
+++ b/changelog.d/10667.misc
@@ -0,0 +1 @@
+Flatten the `tests.synapse.rests` package by moving the contents of `v1` and `v2_alpha` into the parent.
\ No newline at end of file
diff --git a/changelog.d/10672.misc b/changelog.d/10672.misc
new file mode 100644
index 0000000000..7104c121e0
--- /dev/null
+++ b/changelog.d/10672.misc
@@ -0,0 +1 @@
+Run a nightly CI build against Twisted trunk.
diff --git a/changelog.d/10677.bugfix b/changelog.d/10677.bugfix
new file mode 100644
index 0000000000..9964afaaee
--- /dev/null
+++ b/changelog.d/10677.bugfix
@@ -0,0 +1 @@
+Fix a bug which caused the `synapse_user_logins_total` Prometheus metric not to be correctly initialised on restart.
diff --git a/changelog.d/8830.removal b/changelog.d/8830.removal
new file mode 100644
index 0000000000..b3a93a9af2
--- /dev/null
+++ b/changelog.d/8830.removal
@@ -0,0 +1 @@
+Remove deprecated Shutdown Room and Purge Room Admin API.
\ No newline at end of file
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index 56e0141c2b..4fcd2b7852 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -52,12 +52,11 @@
       - [Event Reports](admin_api/event_reports.md)
       - [Media](admin_api/media_admin_api.md)
       - [Purge History](admin_api/purge_history_api.md)
-      - [Purge Rooms](admin_api/purge_room.md)
       - [Register Users](admin_api/register_api.md)
+      - [Registration Tokens](usage/administration/admin_api/registration_tokens.md)
       - [Manipulate Room Membership](admin_api/room_membership.md)
       - [Rooms](admin_api/rooms.md)
       - [Server Notices](admin_api/server_notices.md)
-      - [Shutdown Room](admin_api/shutdown_room.md)
       - [Statistics](admin_api/statistics.md)
       - [Users](admin_api/user_admin_api.md)
       - [Server Version](admin_api/version_api.md)
diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md
deleted file mode 100644
index 54fea2db6d..0000000000
--- a/docs/admin_api/purge_room.md
+++ /dev/null
@@ -1,21 +0,0 @@
-Deprecated: Purge room API
-==========================
-
-**The old Purge room API is deprecated and will be removed in a future release.
-See the new [Delete Room API](rooms.md#delete-room-api) for more details.**
-
-This API will remove all trace of a room from your database.
-
-All local users must have left the room before it can be removed.
-
-The API is:
-
-```
-POST /_synapse/admin/v1/purge_room
-
-{
-    "room_id": "!room:id"
-}
-```
-
-You must authenticate using the access token of an admin user.
diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md
deleted file mode 100644
index 856a629487..0000000000
--- a/docs/admin_api/shutdown_room.md
+++ /dev/null
@@ -1,102 +0,0 @@
-# Deprecated: Shutdown room API
-
-**The old Shutdown room API is deprecated and will be removed in a future release.
-See the new [Delete Room API](rooms.md#delete-room-api) for more details.**
-
-Shuts down a room, preventing new joins and moves local users and room aliases automatically
-to a new room. The new room will be created with the user specified by the
-`new_room_user_id` parameter as room administrator and will contain a message
-explaining what happened. Users invited to the new room will have power level
--10 by default, and thus be unable to speak. The old room's power levels will be changed to
-disallow any further invites or joins.
-
-The local server will only have the power to move local user and room aliases to
-the new room. Users on other servers will be unaffected.
-
-## API
-
-You will need to authenticate with an access token for an admin user.
-
-### URL
-
-`POST /_synapse/admin/v1/shutdown_room/{room_id}`
-
-### URL Parameters
-
-* `room_id` - The ID of the room (e.g `!someroom:example.com`)
-
-### JSON Body Parameters
-
-* `new_room_user_id` - Required. A string representing the user ID of the user that will admin
-                       the new room that all users in the old room will be moved to.
-* `room_name` - Optional. A string representing the name of the room that new users will be
-                invited to.
-* `message` - Optional. A string containing the first message that will be sent as
-              `new_room_user_id` in the new room. Ideally this will clearly convey why the
-               original room was shut down.
-
-If not specified, the default value of `room_name` is "Content Violation
-Notification". The default value of `message` is "Sharing illegal content on
-othis server is not permitted and rooms in violation will be blocked."
-
-### Response Parameters
-
-* `kicked_users` - An integer number representing the number of users that
-                   were kicked.
-* `failed_to_kick_users` - An integer number representing the number of users
-                           that were not kicked.
-* `local_aliases` - An array of strings representing the local aliases that were migrated from
-                    the old room to the new.
-* `new_room_id` - A string representing the room ID of the new room.
-
-## Example
-
-Request:
-
-```
-POST /_synapse/admin/v1/shutdown_room/!somebadroom%3Aexample.com
-
-{
-    "new_room_user_id": "@someuser:example.com",
-    "room_name": "Content Violation Notification",
-    "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service."
-}
-```
-
-Response:
-
-```
-{
-    "kicked_users": 5,
-    "failed_to_kick_users": 0,
-    "local_aliases": ["#badroom:example.com", "#evilsaloon:example.com],
-    "new_room_id": "!newroomid:example.com",
-},
-```
-
-## Undoing room shutdowns
-
-*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level,
-the structure can and does change without notice.
-
-First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it
-never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible
-to recover at all:
-
-* If the room was invite-only, your users will need to be re-invited.
-* If the room no longer has any members at all, it'll be impossible to rejoin.
-* The first user to rejoin will have to do so via an alias on a different server.
-
-With all that being said, if you still want to try and recover the room:
-
-1. For safety reasons, shut down Synapse.
-2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';`
-   * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`.
-   * The room ID is the same one supplied to the shutdown room API, not the Content Violation room.
-3. Restart Synapse.
-
-You will have to manually handle, if you so choose, the following:
-
-* Aliases that would have been redirected to the Content Violation room.
-* Users that would have been booted from the room (and will have been force-joined to the Content Violation room).
-* Removal of the Content Violation room if desired.
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index 6a9335d6ec..60dc913915 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -21,11 +21,15 @@ It returns a JSON body like the following:
     "threepids": [
         {
             "medium": "email",
-            "address": "<user_mail_1>"
+            "address": "<user_mail_1>",
+            "added_at": 1586458409743,
+            "validated_at": 1586458409743
         },
         {
             "medium": "email",
-            "address": "<user_mail_2>"
+            "address": "<user_mail_2>",
+            "added_at": 1586458409743,
+            "validated_at": 1586458409743
         }
     ],
     "avatar_url": "<avatar_url>",
diff --git a/docs/modules.md b/docs/modules.md
index 9a430390a4..ae8d6f5b73 100644
--- a/docs/modules.md
+++ b/docs/modules.md
@@ -282,6 +282,52 @@ the request is a server admin.
 Modules can modify the `request_content` (by e.g. adding events to its `initial_state`),
 or deny the room's creation by raising a `module_api.errors.SynapseError`.
 
+#### Presence router callbacks
+
+Presence router callbacks allow module developers to specify additional users (local or remote)
+to receive certain presence updates from local users. Presence router callbacks can be 
+registered using the module API's `register_presence_router_callbacks` method.
+
+The available presence router callbacks are:
+
+```python 
+async def get_users_for_states(
+    self,
+    state_updates: Iterable["synapse.api.UserPresenceState"],
+) -> Dict[str, Set["synapse.api.UserPresenceState"]]:
+```
+**Requires** `get_interested_users` to also be registered
+
+Called when processing updates to the presence state of one or more users. This callback can
+be used to instruct the server to forward that presence state to specific users. The module
+must return a dictionary that maps from Matrix user IDs (which can be local or remote) to the
+`UserPresenceState` changes that they should be forwarded.
+
+Synapse will then attempt to send the specified presence updates to each user when possible.
+
+```python
+async def get_interested_users(
+        self, 
+        user_id: str
+) -> Union[Set[str], "synapse.module_api.PRESENCE_ALL_USERS"]
+```
+**Requires** `get_users_for_states` to also be registered
+
+Called when determining which users someone should be able to see the presence state of. This
+callback should return complementary results to `get_users_for_state` or the presence information 
+may not be properly forwarded.
+
+The callback is given the Matrix user ID for a local user that is requesting presence data and
+should return the Matrix user IDs of the users whose presence state they are allowed to
+query. The returned users can be local or remote. 
+
+Alternatively the callback can return `synapse.module_api.PRESENCE_ALL_USERS`
+to indicate that the user should receive updates from all known users.
+
+For example, if the user `@alice:example.org` is passed to this method, and the Set 
+`{"@bob:example.com", "@charlie:somewhere.org"}` is returned, this signifies that Alice 
+should receive presence updates sent by Bob and Charlie, regardless of whether these users 
+share a room.
 
 ### Porting an existing module that uses the old interface
 
diff --git a/docs/openid.md b/docs/openid.md
index f685fd551a..49180eec52 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -79,7 +79,7 @@ oidc_providers:
         display_name_template: "{{ user.name }}"
 ```
 
-### [Dex][dex-idp]
+### Dex
 
 [Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider.
 Although it is designed to help building a full-blown provider with an
@@ -117,7 +117,7 @@ oidc_providers:
         localpart_template: "{{ user.name }}"
         display_name_template: "{{ user.name|capitalize }}"
 ```
-### [Keycloak][keycloak-idp]
+### Keycloak
 
 [Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
 
@@ -166,7 +166,9 @@ oidc_providers:
         localpart_template: "{{ user.preferred_username }}"
         display_name_template: "{{ user.name }}"
 ```
-### [Auth0][auth0]
+### Auth0
+
+[Auth0][auth0] is a hosted SaaS IdP solution.
 
 1. Create a regular web application for Synapse
 2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/client/oidc/callback`
@@ -209,7 +211,7 @@ oidc_providers:
 
 ### GitHub
 
-GitHub is a bit special as it is not an OpenID Connect compliant provider, but
+[GitHub][github-idp] is a bit special as it is not an OpenID Connect compliant provider, but
 just a regular OAuth2 provider.
 
 The [`/user` API endpoint](https://developer.github.com/v3/users/#get-the-authenticated-user)
@@ -242,11 +244,13 @@ oidc_providers:
         display_name_template: "{{ user.name }}"
 ```
 
-### [Google][google-idp]
+### Google
+
+[Google][google-idp] is an OpenID certified authentication and authorisation provider.
 
 1. Set up a project in the Google API Console (see
    https://developers.google.com/identity/protocols/oauth2/openid-connect#appsetup).
-2. add an "OAuth Client ID" for a Web Application under "Credentials".
+2. Add an "OAuth Client ID" for a Web Application under "Credentials".
 3. Copy the Client ID and Client Secret, and add the following to your synapse config:
    ```yaml
    oidc_providers:
@@ -446,3 +450,51 @@ The synapse config will look like this:
       config:
         email_template: "{{ user.email }}"
 ```
+
+## Django OAuth Toolkit
+
+[django-oauth-toolkit](https://github.com/jazzband/django-oauth-toolkit) is a
+Django application providing out of the box all the endpoints, data and logic
+needed to add OAuth2 capabilities to your Django projects. It supports
+[OpenID Connect too](https://django-oauth-toolkit.readthedocs.io/en/latest/oidc.html).
+
+Configuration on Django's side:
+
+1. Add an application: https://example.com/admin/oauth2_provider/application/add/ and choose parameters like this:
+* `Redirect uris`: https://synapse.example.com/_synapse/client/oidc/callback
+* `Client type`: `Confidential`
+* `Authorization grant type`: `Authorization code`
+* `Algorithm`: `HMAC with SHA-2 256`
+2. You can [customize the claims](https://django-oauth-toolkit.readthedocs.io/en/latest/oidc.html#customizing-the-oidc-responses) Django gives to synapse (optional):
+   <details>
+    <summary>Code sample</summary>
+
+    ```python
+    class CustomOAuth2Validator(OAuth2Validator):
+
+        def get_additional_claims(self, request):
+            return {
+                "sub": request.user.email,
+                "email": request.user.email,
+                "first_name": request.user.first_name,
+                "last_name": request.user.last_name,
+            }
+    ```
+   </details>
+Your synapse config is then:
+
+```yaml
+oidc_providers:
+  - idp_id: django_example
+    idp_name: "Django Example"
+    issuer: "https://example.com/o/"
+    client_id: "your-client-id"  # CHANGE ME
+    client_secret: "your-client-secret"  # CHANGE ME
+    scopes: ["openid"]
+    user_profile_method: "userinfo_endpoint"  # needed because oauth-toolkit does not include user information in the authorization response
+    user_mapping_provider:
+      config:
+        localpart_template: "{{ user.email.split('@')[0] }}"
+        display_name_template: "{{ user.first_name }} {{ user.last_name }}"
+        email_template: "{{ user.email }}"
+```
diff --git a/docs/presence_router_module.md b/docs/presence_router_module.md
index 4a3e720240..face54fe2b 100644
--- a/docs/presence_router_module.md
+++ b/docs/presence_router_module.md
@@ -1,3 +1,9 @@
+<h2 style="color:red">
+This page of the Synapse documentation is now deprecated. For up to date
+documentation on setting up or writing a presence router module, please see
+<a href="modules.md">this page</a>.
+</h2>
+
 # Presence Router Module
 
 Synapse supports configuring a module that can specify additional users
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 3ec76d5abf..935841dbfa 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -108,20 +108,6 @@ presence:
   #
   #enabled: false
 
-  # Presence routers are third-party modules that can specify additional logic
-  # to where presence updates from users are routed.
-  #
-  presence_router:
-    # The custom module's class. Uncomment to use a custom presence router module.
-    #
-    #module: "my_custom_router.PresenceRouter"
-
-    # Configuration options of the custom module. Refer to your module's
-    # documentation for available options.
-    #
-    #config:
-    #  example_option: 'something'
-
 # Whether to require authentication to retrieve profile data (avatars,
 # display names) of other users through the client API. Defaults to
 # 'false'. Note that profile data is also available via the federation
@@ -807,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config"
 #     is using
 #   - one for registration that ratelimits registration requests based on the
 #     client's IP address.
+#   - one for checking the validity of registration tokens that ratelimits
+#     requests based on the client's IP address.
 #   - one for login that ratelimits login requests based on the client's IP
 #     address.
 #   - one for login that ratelimits login requests based on the account the
@@ -835,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
 #  per_second: 0.17
 #  burst_count: 3
 #
+#rc_registration_token_validity:
+#  per_second: 0.1
+#  burst_count: 5
+#
 #rc_login:
 #  address:
 #    per_second: 0.17
@@ -1183,6 +1175,15 @@ url_preview_accept_language:
 #
 #enable_3pid_lookup: true
 
+# Require users to submit a token during registration.
+# Tokens can be managed using the admin API:
+# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
+# Note that `enable_registration` must be set to `true`.
+# Disabling this option will not delete any tokens previously generated.
+# Defaults to false. Uncomment the following to require tokens:
+#
+#registration_requires_token: true
+
 # If set, allows registration of standard or admin accounts by anyone who
 # has the shared secret, even if registration is otherwise disabled.
 #
diff --git a/docs/upgrade.md b/docs/upgrade.md
index e5d386b02f..6d4b8cb48e 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -85,6 +85,28 @@ process, for example:
     dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
     ```
 
+# Upgrading to v1.xx.0
+
+## Removal of old Room Admin API
+
+The following admin APIs were deprecated in [Synapse 1.25](https://github.com/matrix-org/synapse/blob/v1.25.0/CHANGES.md#removal-warning)
+(released on 2021-01-13) and have now been removed:
+
+-   `POST /_synapse/admin/v1/purge_room`
+-   `POST /_synapse/admin/v1/shutdown_room/<room_id>`
+
+Any scripts still using the above APIs should be converted to use the
+[Delete Room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api).
+
+## User-interactive authentication fallback templates can now display errors
+
+This may affect you if you make use of custom HTML templates for the
+[reCAPTCHA](../synapse/res/templates/recaptcha.html) or
+[terms](../synapse/res/templates/terms.html) fallback pages.
+
+The template is now provided an `error` variable if the authentication
+process failed. See the default templates linked above for an example.
+
 
 # Upgrading to v1.41.0
 
diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md
new file mode 100644
index 0000000000..828c0277d6
--- /dev/null
+++ b/docs/usage/administration/admin_api/registration_tokens.md
@@ -0,0 +1,295 @@
+# Registration Tokens
+
+This API allows you to manage tokens which can be used to authenticate
+registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md).
+To use it, you will need to enable the `registration_requires_token` config
+option, and authenticate by providing an `access_token` for a server admin:
+see [Admin API](../../usage/administration/admin_api).
+Note that this API is still experimental; not all clients may support it yet.
+
+
+## Registration token objects
+
+Most endpoints make use of JSON objects that contain details about tokens.
+These objects have the following fields:
+- `token`: The token which can be used to authenticate registration.
+- `uses_allowed`: The number of times the token can be used to complete a
+  registration before it becomes invalid.
+- `pending`: The number of pending uses the token has. When someone uses
+  the token to authenticate themselves, the pending counter is incremented
+  so that the token is not used more than the permitted number of times.
+  When the person completes registration the pending counter is decremented,
+  and the completed counter is incremented.
+- `completed`: The number of times the token has been used to successfully
+  complete a registration.
+- `expiry_time`: The latest time the token is valid. Given as the number of
+  milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+  To convert this into a human-readable form you can remove the milliseconds
+  and use the `date` command. For example, `date -d '@1625394937'`.
+
+
+## List all tokens
+
+Lists all tokens and details about them. If the request is successful, the top
+level JSON object will have a `registration_tokens` key which is an array of
+registration token objects.
+
+```
+GET /_synapse/admin/v1/registration_tokens
+```
+
+Optional query parameters:
+- `valid`: `true` or `false`. If `true`, only valid tokens are returned.
+  If `false`, only tokens that have expired or have had all uses exhausted are
+  returned. If omitted, all tokens are returned regardless of validity.
+
+Example:
+
+```
+GET /_synapse/admin/v1/registration_tokens
+```
+```
+200 OK
+
+{
+    "registration_tokens": [
+        {
+            "token": "abcd",
+            "uses_allowed": 3,
+            "pending": 0,
+            "completed": 1,
+            "expiry_time": null
+        },
+        {
+            "token": "pqrs",
+            "uses_allowed": 2,
+            "pending": 1,
+            "completed": 1,
+            "expiry_time": null
+        },
+        {
+            "token": "wxyz",
+            "uses_allowed": null,
+            "pending": 0,
+            "completed": 9,
+            "expiry_time": 1625394937000    // 2021-07-04 10:35:37 UTC
+        }
+    ]
+}
+```
+
+Example using the `valid` query parameter:
+
+```
+GET /_synapse/admin/v1/registration_tokens?valid=false
+```
+```
+200 OK
+
+{
+    "registration_tokens": [
+        {
+            "token": "pqrs",
+            "uses_allowed": 2,
+            "pending": 1,
+            "completed": 1,
+            "expiry_time": null
+        },
+        {
+            "token": "wxyz",
+            "uses_allowed": null,
+            "pending": 0,
+            "completed": 9,
+            "expiry_time": 1625394937000    // 2021-07-04 10:35:37 UTC
+        }
+    ]
+}
+```
+
+
+## Get one token
+
+Get details about a single token. If the request is successful, the response
+body will be a registration token object.
+
+```
+GET /_synapse/admin/v1/registration_tokens/<token>
+```
+
+Path parameters:
+- `token`: The registration token to return details of.
+
+Example:
+
+```
+GET /_synapse/admin/v1/registration_tokens/abcd
+```
+```
+200 OK
+
+{
+    "token": "abcd",
+    "uses_allowed": 3,
+    "pending": 0,
+    "completed": 1,
+    "expiry_time": null
+}
+```
+
+
+## Create token
+
+Create a new registration token. If the request is successful, the newly created
+token will be returned as a registration token object in the response body.
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+```
+
+The request body must be a JSON object and can contain the following fields:
+- `token`: The registration token. A string of no more than 64 characters that
+  consists only of characters matched by the regex `[A-Za-z0-9-_]`.
+  Default: randomly generated.
+- `uses_allowed`: The integer number of times the token can be used to complete
+  a registration before it becomes invalid.
+  Default: `null` (unlimited uses).
+- `expiry_time`: The latest time the token is valid. Given as the number of
+  milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+  You could use, for example, `date '+%s000' -d 'tomorrow'`.
+  Default: `null` (token does not expire).
+- `length`: The length of the token randomly generated if `token` is not
+  specified. Must be between 1 and 64 inclusive. Default: `16`.
+
+If a field is omitted the default is used.
+
+Example using defaults:
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+
+{}
+```
+```
+200 OK
+
+{
+    "token": "0M-9jbkf2t_Tgiw1",
+    "uses_allowed": null,
+    "pending": 0,
+    "completed": 0,
+    "expiry_time": null
+}
+```
+
+Example specifying some fields:
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+
+{
+    "token": "defg",
+    "uses_allowed": 1
+}
+```
+```
+200 OK
+
+{
+    "token": "defg",
+    "uses_allowed": 1,
+    "pending": 0,
+    "completed": 0,
+    "expiry_time": null
+}
+```
+
+
+## Update token
+
+Update the number of allowed uses or expiry time of a token. If the request is
+successful, the updated token will be returned as a registration token object
+in the response body.
+
+```
+PUT /_synapse/admin/v1/registration_tokens/<token>
+```
+
+Path parameters:
+- `token`: The registration token to update.
+
+The request body must be a JSON object and can contain the following fields:
+- `uses_allowed`: The integer number of times the token can be used to complete
+  a registration before it becomes invalid. By setting `uses_allowed` to `0`
+  the token can be easily made invalid without deleting it.
+  If `null` the token will have an unlimited number of uses.
+- `expiry_time`: The latest time the token is valid. Given as the number of
+  milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+  If `null` the token will not expire.
+
+If a field is omitted its value is not modified.
+
+Example:
+
+```
+PUT /_synapse/admin/v1/registration_tokens/defg
+
+{
+    "expiry_time": 4781243146000    // 2121-07-06 11:05:46 UTC
+}
+```
+```
+200 OK
+
+{
+    "token": "defg",
+    "uses_allowed": 1,
+    "pending": 0,
+    "completed": 0,
+    "expiry_time": 4781243146000
+}
+```
+
+
+## Delete token
+
+Delete a registration token. If the request is successful, the response body
+will be an empty JSON object.
+
+```
+DELETE /_synapse/admin/v1/registration_tokens/<token>
+```
+
+Path parameters:
+- `token`: The registration token to delete.
+
+Example:
+
+```
+DELETE /_synapse/admin/v1/registration_tokens/wxyz
+```
+```
+200 OK
+
+{}
+```
+
+
+## Errors
+
+If a request fails a "standard error response" will be returned as defined in
+the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards).
+
+For example, if the token specified in a path parameter does not exist a
+`404 Not Found` error will be returned.
+
+```
+GET /_synapse/admin/v1/registration_tokens/1234
+```
+```
+404 Not Found
+
+{
+    "errcode": "M_NOT_FOUND",
+    "error": "No such registration token: 1234"
+}
+```
diff --git a/docs/workers.md b/docs/workers.md
index 2e63f03452..3121241894 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -236,6 +236,7 @@ expressions:
     # Registration/login requests
     ^/_matrix/client/(api/v1|r0|unstable)/login$
     ^/_matrix/client/(r0|unstable)/register$
+    ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$
 
     # Event sending requests
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
diff --git a/mypy.ini b/mypy.ini
index e1b9405daa..745e6b78eb 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -28,10 +28,13 @@ files =
   synapse/federation,
   synapse/groups,
   synapse/handlers,
+  synapse/http/additional_resource.py,
   synapse/http/client.py,
   synapse/http/federation/matrix_federation_agent.py,
+  synapse/http/federation/srv_resolver.py,
   synapse/http/federation/well_known_resolver.py,
   synapse/http/matrixfederationclient.py,
+  synapse/http/proxyagent.py,
   synapse/http/servlet.py,
   synapse/http/server.py,
   synapse/http/site.py,
@@ -54,6 +57,7 @@ files =
   synapse/storage/databases/main/keys.py,
   synapse/storage/databases/main/pusher.py,
   synapse/storage/databases/main/registration.py,
+  synapse/storage/databases/main/session.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
@@ -87,8 +91,9 @@ files =
   tests/test_utils,
   tests/handlers/test_password_providers.py,
   tests/handlers/test_room_summary.py,
-  tests/rest/client/v1/test_login.py,
-  tests/rest/client/v2_alpha/test_auth.py,
+  tests/handlers/test_sync.py,
+  tests/rest/client/test_login.py,
+  tests/rest/client/test_auth.py,
   tests/util/test_itertools.py,
   tests/util/test_stream_change_cache.py
 
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e0e24fddac..829061c870 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -79,6 +79,7 @@ class LoginType:
     TERMS = "m.login.terms"
     SSO = "m.login.sso"
     DUMMY = "m.login.dummy"
+    REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
 
 
 # This is used in the `type` parameter for /register when called by
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index dc662bca83..9480f448d7 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -147,6 +147,14 @@ class SynapseError(CodeMessageException):
         return cs_error(self.msg, self.errcode)
 
 
+class InvalidAPICallError(SynapseError):
+    """You called an existing API endpoint, but fed that endpoint
+    invalid or incomplete data."""
+
+    def __init__(self, msg: str):
+        super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON)
+
+
 class ProxiedRequestError(SynapseError):
     """An error from a general matrix endpoint, eg. from a proxied Matrix API call.
 
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 50a02f51f5..39e28aff9f 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -37,6 +37,7 @@ from synapse.app import check_bind_error
 from synapse.app.phone_stats_home import start_phone_stats_home
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
+from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.logging.context import PreserveLoggingContext
@@ -370,6 +371,7 @@ async def start(hs: "HomeServer"):
 
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
+    load_legacy_presence_router(hs)
 
     # If we've configured an expiry time for caches, start the background job now.
     setup_expire_lru_cache_entries(hs)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 845e6a8220..9b71dd75e6 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
     ProfileRestServlet,
 )
 from synapse.rest.client.push_rule import PushRuleRestServlet
-from synapse.rest.client.register import RegisterRestServlet
+from synapse.rest.client.register import (
+    RegisterRestServlet,
+    RegistrationTokenValidityRestServlet,
+)
 from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.client.voip import VoipRestServlet
@@ -115,6 +118,7 @@ from synapse.storage.databases.main.monthly_active_users import (
 from synapse.storage.databases.main.presence import PresenceStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.databases.main.session import SessionStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.databases.main.transactions import TransactionWorkerStore
 from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
@@ -250,6 +254,7 @@ class GenericWorkerSlavedStore(
     SearchStore,
     TransactionWorkerStore,
     LockStore,
+    SessionStore,
     BaseSlavedStore,
 ):
     pass
@@ -279,6 +284,7 @@ class GenericWorkerServer(HomeServer):
                     resource = JsonResource(self, canonical_json=False)
 
                     RegisterRestServlet(self).register(resource)
+                    RegistrationTokenValidityRestServlet(self).register(resource)
                     login.register_servlets(self, resource)
                     ThreepidRestServlet(self).register(resource)
                     DevicesRestServlet(self).register(resource)
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 907df9591a..95deda11a5 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -39,5 +39,8 @@ class ExperimentalConfig(Config):
         # MSC3244 (room version capabilities)
         self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
 
+        # MSC3283 (set displayname, avatar_url and change 3pid capabilities)
+        self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False)
+
         # MSC3266 (room summary api)
         self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 7a8d5851c4..f856327bd8 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -79,6 +79,11 @@ class RatelimitConfig(Config):
 
         self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
 
+        self.rc_registration_token_validity = RateLimitConfig(
+            config.get("rc_registration_token_validity", {}),
+            defaults={"per_second": 0.1, "burst_count": 5},
+        )
+
         rc_login_config = config.get("rc_login", {})
         self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
         self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
@@ -143,6 +148,8 @@ class RatelimitConfig(Config):
         #     is using
         #   - one for registration that ratelimits registration requests based on the
         #     client's IP address.
+        #   - one for checking the validity of registration tokens that ratelimits
+        #     requests based on the client's IP address.
         #   - one for login that ratelimits login requests based on the client's IP
         #     address.
         #   - one for login that ratelimits login requests based on the account the
@@ -171,6 +178,10 @@ class RatelimitConfig(Config):
         #  per_second: 0.17
         #  burst_count: 3
         #
+        #rc_registration_token_validity:
+        #  per_second: 0.1
+        #  burst_count: 5
+        #
         #rc_login:
         #  address:
         #    per_second: 0.17
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 0ad919b139..7cffdacfa5 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -33,6 +33,9 @@ class RegistrationConfig(Config):
         self.registrations_require_3pid = config.get("registrations_require_3pid", [])
         self.allowed_local_3pids = config.get("allowed_local_3pids", [])
         self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
+        self.registration_requires_token = config.get(
+            "registration_requires_token", False
+        )
         self.registration_shared_secret = config.get("registration_shared_secret")
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -140,6 +143,9 @@ class RegistrationConfig(Config):
                 "mechanism by removing the `access_token_lifetime` option."
             )
 
+        # The fallback template used for authenticating using a registration token
+        self.registration_token_template = self.read_template("registration_token.html")
+
         # The success template used during fallback auth.
         self.fallback_success_template = self.read_template("auth_success.html")
 
@@ -199,6 +205,15 @@ class RegistrationConfig(Config):
         #
         #enable_3pid_lookup: true
 
+        # Require users to submit a token during registration.
+        # Tokens can be managed using the admin API:
+        # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html
+        # Note that `enable_registration` must be set to `true`.
+        # Disabling this option will not delete any tokens previously generated.
+        # Defaults to false. Uncomment the following to require tokens:
+        #
+        #registration_requires_token: true
+
         # If set, allows registration of standard or admin accounts by anyone who
         # has the shared secret, even if registration is otherwise disabled.
         #
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 8494795919..d2c900f50c 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -248,6 +248,7 @@ class ServerConfig(Config):
             self.use_presence = config.get("use_presence", True)
 
         # Custom presence router module
+        # This is the legacy way of configuring it (the config should now be put in the modules section)
         self.presence_router_module_class = None
         self.presence_router_config = None
         presence_router_config = presence_config.get("presence_router")
@@ -870,20 +871,6 @@ class ServerConfig(Config):
           #
           #enabled: false
 
-          # Presence routers are third-party modules that can specify additional logic
-          # to where presence updates from users are routed.
-          #
-          presence_router:
-            # The custom module's class. Uncomment to use a custom presence router module.
-            #
-            #module: "my_custom_router.PresenceRouter"
-
-            # Configuration options of the custom module. Refer to your module's
-            # documentation for available options.
-            #
-            #config:
-            #  example_option: 'something'
-
         # Whether to require authentication to retrieve profile data (avatars,
         # display names) of other users through the client API. Defaults to
         # 'false'. Note that profile data is also available via the federation
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
index 6c37c8a7a4..eb4556cdc1 100644
--- a/synapse/events/presence_router.py
+++ b/synapse/events/presence_router.py
@@ -11,45 +11,115 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
+import logging
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Union,
+)
 
 from synapse.api.presence import UserPresenceState
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+GET_USERS_FOR_STATES_CALLBACK = Callable[
+    [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
+]
+GET_INTERESTED_USERS_CALLBACK = Callable[
+    [str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
+]
+
+logger = logging.getLogger(__name__)
+
+
+def load_legacy_presence_router(hs: "HomeServer"):
+    """Wrapper that loads a presence router module configured using the old
+    configuration, and registers the hooks they implement.
+    """
+
+    if hs.config.presence_router_module_class is None:
+        return
+
+    module = hs.config.presence_router_module_class
+    config = hs.config.presence_router_config
+    api = hs.get_module_api()
+
+    presence_router = module(config=config, module_api=api)
+
+    # The known hooks. If a module implements a method which name appears in this set,
+    # we'll want to register it.
+    presence_router_methods = {
+        "get_users_for_states",
+        "get_interested_users",
+    }
+
+    # All methods that the module provides should be async, but this wasn't enforced
+    # in the old module system, so we wrap them if needed
+    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        def run(*args, **kwargs):
+            # mypy doesn't do well across function boundaries so we need to tell it
+            # f is definitely not None.
+            assert f is not None
+
+            return maybe_awaitable(f(*args, **kwargs))
+
+        return run
+
+    # Register the hooks through the module API.
+    hooks = {
+        hook: async_wrapper(getattr(presence_router, hook, None))
+        for hook in presence_router_methods
+    }
+
+    api.register_presence_router_callbacks(**hooks)
+
 
 class PresenceRouter:
     """
     A module that the homeserver will call upon to help route user presence updates to
-    additional destinations. If a custom presence router is configured, calls will be
-    passed to that instead.
+    additional destinations.
     """
 
     ALL_USERS = "ALL"
 
     def __init__(self, hs: "HomeServer"):
-        self.custom_presence_router = None
+        # Initially there are no callbacks
+        self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
+        self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []
 
-        # Check whether a custom presence router module has been configured
-        if hs.config.presence_router_module_class:
-            # Initialise the module
-            self.custom_presence_router = hs.config.presence_router_module_class(
-                config=hs.config.presence_router_config, module_api=hs.get_module_api()
+    def register_presence_router_callbacks(
+        self,
+        get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
+        get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
+    ):
+        # PresenceRouter modules are required to implement both of these methods
+        # or neither of them as they are assumed to act in a complementary manner
+        paired_methods = [get_users_for_states, get_interested_users]
+        if paired_methods.count(None) == 1:
+            raise RuntimeError(
+                "PresenceRouter modules must register neither or both of the paired callbacks: "
+                "[get_users_for_states, get_interested_users]"
             )
 
-            # Ensure the module has implemented the required methods
-            required_methods = ["get_users_for_states", "get_interested_users"]
-            for method_name in required_methods:
-                if not hasattr(self.custom_presence_router, method_name):
-                    raise Exception(
-                        "PresenceRouter module '%s' must implement all required methods: %s"
-                        % (
-                            hs.config.presence_router_module_class.__name__,
-                            ", ".join(required_methods),
-                        )
-                    )
+        # Append the methods provided to the lists of callbacks
+        if get_users_for_states is not None:
+            self._get_users_for_states_callbacks.append(get_users_for_states)
+
+        if get_interested_users is not None:
+            self._get_interested_users_callbacks.append(get_interested_users)
 
     async def get_users_for_states(
         self,
@@ -66,14 +136,40 @@ class PresenceRouter:
           A dictionary of user_id -> set of UserPresenceState, indicating which
           presence updates each user should receive.
         """
-        if self.custom_presence_router is not None:
-            # Ask the custom module
-            return await self.custom_presence_router.get_users_for_states(
-                state_updates=state_updates
-            )
 
-        # Don't include any extra destinations for presence updates
-        return {}
+        # Bail out early if we don't have any callbacks to run.
+        if len(self._get_users_for_states_callbacks) == 0:
+            # Don't include any extra destinations for presence updates
+            return {}
+
+        users_for_states = {}
+        # run all the callbacks for get_users_for_states and combine the results
+        for callback in self._get_users_for_states_callbacks:
+            try:
+                result = await callback(state_updates)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
+
+            if not isinstance(result, Dict):
+                logger.warning(
+                    "Wrong type returned by module API callback %s: %s, expected Dict",
+                    callback,
+                    result,
+                )
+                continue
+
+            for key, new_entries in result.items():
+                if not isinstance(new_entries, Set):
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected Set",
+                        callback,
+                        new_entries,
+                    )
+                    break
+                users_for_states.setdefault(key, set()).update(new_entries)
+
+        return users_for_states
 
     async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
         """
@@ -92,12 +188,36 @@ class PresenceRouter:
             A set of user IDs to return presence updates for, or ALL_USERS to return all
             known updates.
         """
-        if self.custom_presence_router is not None:
-            # Ask the custom module for interested users
-            return await self.custom_presence_router.get_interested_users(
-                user_id=user_id
-            )
 
-        # A custom presence router is not defined.
-        # Don't report any additional interested users
-        return set()
+        # Bail out early if we don't have any callbacks to run.
+        if len(self._get_interested_users_callbacks) == 0:
+            # Don't report any additional interested users
+            return set()
+
+        interested_users = set()
+        # run all the callbacks for get_interested_users and combine the results
+        for callback in self._get_interested_users_callbacks:
+            try:
+                result = await callback(user_id)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
+
+            # If one of the callbacks returns ALL_USERS then we can stop calling all
+            # of the other callbacks, since the set of interested_users is already as
+            # large as it can possibly be
+            if result == PresenceRouter.ALL_USERS:
+                return PresenceRouter.ALL_USERS
+
+            if not isinstance(result, Set):
+                logger.warning(
+                    "Wrong type returned by module API callback %s: %s, expected set",
+                    callback,
+                    result,
+                )
+                continue
+
+            # Add the new interested users to the set
+            interested_users.update(result)
+
+        return interested_users
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 29979414e3..44d9e8a5c7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
     Codes,
     FederationDeniedError,
     HttpResponseException,
+    RequestSendFailed,
     SynapseError,
     UnsupportedRoomVersionError,
 )
@@ -558,7 +559,11 @@ class FederationClient(FederationBase):
 
             try:
                 return await callback(destination)
-            except InvalidResponseError as e:
+            except (
+                RequestSendFailed,
+                InvalidResponseError,
+                NotRetryingDestination,
+            ) as e:
                 logger.warning("Failed to %s via %s: %s", description, destination, e)
             except UnsupportedRoomVersionError:
                 raise
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index afd8f8580a..e1b58d40c5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1005,9 +1005,7 @@ class FederationServer(FederationBase):
             async with lock:
                 logger.info("handling received PDU: %s", event)
                 try:
-                    await self.handler.on_receive_pdu(
-                        origin, event, sent_to_us_directly=True
-                    )
+                    await self.handler.on_receive_pdu(origin, event)
                 except FederationError as e:
                     # XXX: Ideally we'd inform the remote we failed to process
                     # the event, but we can't return an error in the transaction
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 161b3c933c..98d3d2d97f 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -627,23 +627,28 @@ class AuthHandler(BaseHandler):
 
     async def add_oob_auth(
         self, stagetype: str, authdict: Dict[str, Any], clientip: str
-    ) -> bool:
+    ) -> None:
         """
         Adds the result of out-of-band authentication into an existing auth
         session. Currently used for adding the result of fallback auth.
+
+        Raises:
+            LoginError if the stagetype is unknown or the session is missing.
+            LoginError is raised by check_auth if authentication fails.
         """
         if stagetype not in self.checkers:
-            raise LoginError(400, "", Codes.MISSING_PARAM)
+            raise LoginError(
+                400, f"Unknown UIA stage type: {stagetype}", Codes.INVALID_PARAM
+            )
         if "session" not in authdict:
-            raise LoginError(400, "", Codes.MISSING_PARAM)
+            raise LoginError(400, "Missing session ID", Codes.MISSING_PARAM)
 
+        # If authentication fails a LoginError is raised. Otherwise, store
+        # the successful result.
         result = await self.checkers[stagetype].check_auth(authdict, clientip)
-        if result:
-            await self.store.mark_ui_auth_stage_complete(
-                authdict["session"], stagetype, result
-            )
-            return True
-        return False
+        await self.store.mark_ui_auth_stage_complete(
+            authdict["session"], stagetype, result
+        )
 
     def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
         """
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0e13bdaac..246df43501 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -203,18 +203,13 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    async def on_receive_pdu(
-        self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
-    ) -> None:
-        """Process a PDU received via a federation /send/ transaction, or
-        via backfill of missing prev_events
+    async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
+        """Process a PDU received via a federation /send/ transaction
 
         Args:
             origin: server which initiated the /send/ transaction. Will
                 be used to fetch missing events or state.
             pdu: received PDU
-            sent_to_us_directly: True if this event was pushed to us; False if
-                we pulled it as the result of a missing prev_event.
         """
 
         room_id = pdu.room_id
@@ -276,8 +271,6 @@ class FederationHandler(BaseHandler):
             )
             return None
 
-        state = None
-
         # Check that the event passes auth based on the state at the event. This is
         # done for events that are to be added to the timeline (non-outliers).
         #
@@ -285,23 +278,16 @@ class FederationHandler(BaseHandler):
         #  - Fetching any missing prev events to fill in gaps in the graph
         #  - Fetching state if we have a hole in the graph
         if not pdu.internal_metadata.is_outlier():
-            # We only backfill backwards to the min depth.
-            min_depth = await self.get_min_depth_for_context(pdu.room_id)
-
-            logger.debug("min_depth: %d", min_depth)
-
             prevs = set(pdu.prev_event_ids())
             seen = await self.store.have_events_in_timeline(prevs)
+            missing_prevs = prevs - seen
+
+            if missing_prevs:
+                # We only backfill backwards to the min depth.
+                min_depth = await self.get_min_depth_for_context(pdu.room_id)
+                logger.debug("min_depth: %d", min_depth)
 
-            if min_depth is not None and pdu.depth < min_depth:
-                # This is so that we don't notify the user about this
-                # message, to work around the fact that some events will
-                # reference really really old events we really don't want to
-                # send to the clients.
-                pdu.internal_metadata.outlier = True
-            elif min_depth is not None and pdu.depth > min_depth:
-                missing_prevs = prevs - seen
-                if sent_to_us_directly and missing_prevs:
+                if min_depth is not None and pdu.depth > min_depth:
                     # If we're missing stuff, ensure we only fetch stuff one
                     # at a time.
                     logger.info(
@@ -325,42 +311,23 @@ class FederationHandler(BaseHandler):
                                 % (event_id, e)
                             ) from e
 
-                        # Update the set of things we've seen after trying to
-                        # fetch the missing stuff
-                        seen = await self.store.have_events_in_timeline(prevs)
-
-                        if not prevs - seen:
-                            logger.info(
-                                "Found all missing prev_events",
-                            )
-
-            missing_prevs = prevs - seen
-            if missing_prevs:
-                # We've still not been able to get all of the prev_events for this event.
-                #
-                # In this case, we need to fall back to asking another server in the
-                # federation for the state at this event. That's ok provided we then
-                # resolve the state against other bits of the DAG before using it (which
-                # will ensure that you can't just take over a room by sending an event,
-                # withholding its prev_events, and declaring yourself to be an admin in
-                # the subsequent state request).
-                #
-                # Now, if we're pulling this event as a missing prev_event, then clearly
-                # this event is not going to become the only forward-extremity and we are
-                # guaranteed to resolve its state against our existing forward
-                # extremities, so that should be fine.
-                #
-                # On the other hand, if this event was pushed to us, it is possible for
-                # it to become the only forward-extremity in the room, and we would then
-                # trust its state to be the state for the whole room. This is very bad.
-                # Further, if the event was pushed to us, there is no excuse for us not to
-                # have all the prev_events. We therefore reject any such events.
-                #
-                # XXX this really feels like it could/should be merged with the above,
-                # but there is an interaction with min_depth that I'm not really
-                # following.
-
-                if sent_to_us_directly:
+                    # Update the set of things we've seen after trying to
+                    # fetch the missing stuff
+                    seen = await self.store.have_events_in_timeline(prevs)
+                    missing_prevs = prevs - seen
+
+                    if not missing_prevs:
+                        logger.info("Found all missing prev_events")
+
+                if missing_prevs:
+                    # since this event was pushed to us, it is possible for it to
+                    # become the only forward-extremity in the room, and we would then
+                    # trust its state to be the state for the whole room. This is very
+                    # bad. Further, if the event was pushed to us, there is no excuse
+                    # for us not to have all the prev_events. (XXX: apart from
+                    # min_depth?)
+                    #
+                    # We therefore reject any such events.
                     logger.warning(
                         "Rejecting: failed to fetch %d prev events: %s",
                         len(missing_prevs),
@@ -376,93 +343,7 @@ class FederationHandler(BaseHandler):
                         affected=pdu.event_id,
                     )
 
-                logger.info(
-                    "Event %s is missing prev_events %s: calculating state for a "
-                    "backwards extremity",
-                    event_id,
-                    shortstr(missing_prevs),
-                )
-
-                # Calculate the state after each of the previous events, and
-                # resolve them to find the correct state at the current event.
-                event_map = {event_id: pdu}
-                try:
-                    # Get the state of the events we know about
-                    ours = await self.state_store.get_state_groups_ids(room_id, seen)
-
-                    # state_maps is a list of mappings from (type, state_key) to event_id
-                    state_maps: List[StateMap[str]] = list(ours.values())
-
-                    # we don't need this any more, let's delete it.
-                    del ours
-
-                    # Ask the remote server for the states we don't
-                    # know about
-                    for p in missing_prevs:
-                        logger.info("Requesting state after missing prev_event %s", p)
-
-                        with nested_logging_context(p):
-                            # note that if any of the missing prevs share missing state or
-                            # auth events, the requests to fetch those events are deduped
-                            # by the get_pdu_cache in federation_client.
-                            remote_state = (
-                                await self._get_state_after_missing_prev_event(
-                                    origin, room_id, p
-                                )
-                            )
-
-                            remote_state_map = {
-                                (x.type, x.state_key): x.event_id for x in remote_state
-                            }
-                            state_maps.append(remote_state_map)
-
-                            for x in remote_state:
-                                event_map[x.event_id] = x
-
-                    room_version = await self.store.get_room_version_id(room_id)
-                    state_map = (
-                        await self._state_resolution_handler.resolve_events_with_store(
-                            room_id,
-                            room_version,
-                            state_maps,
-                            event_map,
-                            state_res_store=StateResolutionStore(self.store),
-                        )
-                    )
-
-                    # We need to give _process_received_pdu the actual state events
-                    # rather than event ids, so generate that now.
-
-                    # First though we need to fetch all the events that are in
-                    # state_map, so we can build up the state below.
-                    evs = await self.store.get_events(
-                        list(state_map.values()),
-                        get_prev_content=False,
-                        redact_behaviour=EventRedactBehaviour.AS_IS,
-                    )
-                    event_map.update(evs)
-
-                    state = [event_map[e] for e in state_map.values()]
-                except Exception:
-                    logger.warning(
-                        "Error attempting to resolve state at missing " "prev_events",
-                        exc_info=True,
-                    )
-                    raise FederationError(
-                        "ERROR",
-                        403,
-                        "We can't get valid state history.",
-                        affected=event_id,
-                    )
-
-        # A second round of checks for all events. Check that the event passes auth
-        # based on `auth_events`, this allows us to assert that the event would
-        # have been allowed at some point. If an event passes this check its OK
-        # for it to be used as part of a returned `/state` request, as either
-        # a) we received the event as part of the original join and so trust it, or
-        # b) we'll do a state resolution with existing state before it becomes
-        # part of the "current state", which adds more protection.
-        await self._process_received_pdu(origin, pdu, state=state)
+        await self._process_received_pdu(origin, pdu, state=None)
 
     async def _get_missing_events_for_pdu(
         self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
@@ -562,24 +443,7 @@ class FederationHandler(BaseHandler):
             return
 
         logger.info("Got %d prev_events", len(missing_events))
-
-        # We want to sort these by depth so we process them and
-        # tell clients about them in order.
-        missing_events.sort(key=lambda x: x.depth)
-
-        for ev in missing_events:
-            logger.info("Handling received prev_event %s", ev)
-            with nested_logging_context(ev.event_id):
-                try:
-                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
-                except FederationError as e:
-                    if e.code == 403:
-                        logger.warning(
-                            "Received prev_event %s failed history check.",
-                            ev.event_id,
-                        )
-                    else:
-                        raise
+        await self._process_pulled_events(origin, missing_events)
 
     async def _get_state_for_room(
         self,
@@ -1496,6 +1360,198 @@ class FederationHandler(BaseHandler):
                 event_infos,
             )
 
+    async def _process_pulled_events(
+        self, origin: str, events: Iterable[EventBase]
+    ) -> None:
+        """Process a batch of events we have pulled from a remote server
+
+        Pulls in any events required to auth the events, persists the received events,
+        and notifies clients, if appropriate.
+
+        Assumes the events have already had their signatures and hashes checked.
+
+        Params:
+            origin: The server we received these events from
+            events: The received events.
+        """
+
+        # We want to sort these by depth so we process them and
+        # tell clients about them in order.
+        sorted_events = sorted(events, key=lambda x: x.depth)
+
+        for ev in sorted_events:
+            with nested_logging_context(ev.event_id):
+                await self._process_pulled_event(origin, ev)
+
+    async def _process_pulled_event(self, origin: str, event: EventBase) -> None:
+        """Process a single event that we have pulled from a remote server
+
+        Pulls in any events required to auth the event, persists the received event,
+        and notifies clients, if appropriate.
+
+        Assumes the event has already had its signatures and hashes checked.
+
+        This is somewhat equivalent to on_receive_pdu, but applies somewhat different
+        logic in the case that we are missing prev_events (in particular, it just
+        requests the state at that point, rather than triggering a get_missing_events) -
+        so is appropriate when we have pulled the event from a remote server, rather
+        than having it pushed to us.
+
+        Params:
+            origin: The server we received this event from
+            events: The received event
+        """
+        logger.info("Processing pulled event %s", event)
+
+        # these should not be outliers.
+        assert not event.internal_metadata.is_outlier()
+
+        event_id = event.event_id
+
+        existing = await self.store.get_event(
+            event_id, allow_none=True, allow_rejected=True
+        )
+        if existing:
+            if not existing.internal_metadata.is_outlier():
+                logger.info(
+                    "Ignoring received event %s which we have already seen",
+                    event_id,
+                )
+                return
+            logger.info("De-outliering event %s", event_id)
+
+        try:
+            self._sanity_check_event(event)
+        except SynapseError as err:
+            logger.warning("Event %s failed sanity check: %s", event_id, err)
+            return
+
+        try:
+            state = await self._resolve_state_at_missing_prevs(origin, event)
+            await self._process_received_pdu(origin, event, state=state)
+        except FederationError as e:
+            if e.code == 403:
+                logger.warning("Pulled event %s failed history check.", event_id)
+            else:
+                raise
+
+    async def _resolve_state_at_missing_prevs(
+        self, dest: str, event: EventBase
+    ) -> Optional[Iterable[EventBase]]:
+        """Calculate the state at an event with missing prev_events.
+
+        This is used when we have pulled a batch of events from a remote server, and
+        still don't have all the prev_events.
+
+        If we already have all the prev_events for `event`, this method does nothing.
+
+        Otherwise, the missing prevs become new backwards extremities, and we fall back
+        to asking the remote server for the state after each missing `prev_event`,
+        and resolving across them.
+
+        That's ok provided we then resolve the state against other bits of the DAG
+        before using it - in other words, that the received event `event` is not going
+        to become the only forwards_extremity in the room (which will ensure that you
+        can't just take over a room by sending an event, withholding its prev_events,
+        and declaring yourself to be an admin in the subsequent state request).
+
+        In other words: we should only call this method if `event` has been *pulled*
+        as part of a batch of missing prev events, or similar.
+
+        Params:
+            dest: the remote server to ask for state at the missing prevs. Typically,
+                this will be the server we got `event` from.
+            event: an event to check for missing prevs.
+
+        Returns:
+            if we already had all the prev events, `None`. Otherwise, returns a list of
+            the events in the state at `event`.
+        """
+        room_id = event.room_id
+        event_id = event.event_id
+
+        prevs = set(event.prev_event_ids())
+        seen = await self.store.have_events_in_timeline(prevs)
+        missing_prevs = prevs - seen
+
+        if not missing_prevs:
+            return None
+
+        logger.info(
+            "Event %s is missing prev_events %s: calculating state for a "
+            "backwards extremity",
+            event_id,
+            shortstr(missing_prevs),
+        )
+        # Calculate the state after each of the previous events, and
+        # resolve them to find the correct state at the current event.
+        event_map = {event_id: event}
+        try:
+            # Get the state of the events we know about
+            ours = await self.state_store.get_state_groups_ids(room_id, seen)
+
+            # state_maps is a list of mappings from (type, state_key) to event_id
+            state_maps: List[StateMap[str]] = list(ours.values())
+
+            # we don't need this any more, let's delete it.
+            del ours
+
+            # Ask the remote server for the states we don't
+            # know about
+            for p in missing_prevs:
+                logger.info("Requesting state after missing prev_event %s", p)
+
+                with nested_logging_context(p):
+                    # note that if any of the missing prevs share missing state or
+                    # auth events, the requests to fetch those events are deduped
+                    # by the get_pdu_cache in federation_client.
+                    remote_state = await self._get_state_after_missing_prev_event(
+                        dest, room_id, p
+                    )
+
+                    remote_state_map = {
+                        (x.type, x.state_key): x.event_id for x in remote_state
+                    }
+                    state_maps.append(remote_state_map)
+
+                    for x in remote_state:
+                        event_map[x.event_id] = x
+
+            room_version = await self.store.get_room_version_id(room_id)
+            state_map = await self._state_resolution_handler.resolve_events_with_store(
+                room_id,
+                room_version,
+                state_maps,
+                event_map,
+                state_res_store=StateResolutionStore(self.store),
+            )
+
+            # We need to give _process_received_pdu the actual state events
+            # rather than event ids, so generate that now.
+
+            # First though we need to fetch all the events that are in
+            # state_map, so we can build up the state below.
+            evs = await self.store.get_events(
+                list(state_map.values()),
+                get_prev_content=False,
+                redact_behaviour=EventRedactBehaviour.AS_IS,
+            )
+            event_map.update(evs)
+
+            state = [event_map[e] for e in state_map.values()]
+        except Exception:
+            logger.warning(
+                "Error attempting to resolve state at missing prev_events",
+                exc_info=True,
+            )
+            raise FederationError(
+                "ERROR",
+                403,
+                "We can't get valid state history.",
+                affected=event_id,
+            )
+        return state
+
     def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event
@@ -1764,7 +1820,7 @@ class FederationHandler(BaseHandler):
                     p,
                 )
                 with nested_logging_context(p.event_id):
-                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p)
             except Exception as e:
                 logger.warning(
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -2375,6 +2431,7 @@ class FederationHandler(BaseHandler):
                 not event.internal_metadata.is_outlier()
                 and not backfilled
                 and not context.rejected
+                and (await self.store.get_min_depth(event.room_id)) <= event.depth
             ):
                 await self.action_generator.handle_push_actions_for_event(
                     event, context
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e1c544a3c9..4e8f7f1d85 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -151,7 +151,7 @@ class InitialSyncHandler(BaseHandler):
             limit = 10
 
         async def handle_room(event: RoomsForUser):
-            d = {
+            d: JsonDict = {
                 "room_id": event.room_id,
                 "membership": event.membership,
                 "visibility": (
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8cf614136e..0ed59d757b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -56,6 +56,22 @@ login_counter = Counter(
 )
 
 
+def init_counters_for_auth_provider(auth_provider_id: str) -> None:
+    """Ensure the prometheus counters for the given auth provider are initialised
+
+    This fixes a problem where the counters are not reported for a given auth provider
+    until the user first logs in/registers.
+    """
+    for is_guest in (True, False):
+        login_counter.labels(guest=is_guest, auth_provider=auth_provider_id)
+        for shadow_banned in (True, False):
+            registration_counter.labels(
+                guest=is_guest,
+                shadow_banned=shadow_banned,
+                auth_provider=auth_provider_id,
+            )
+
+
 class LoginDict(TypedDict):
     device_id: str
     access_token: str
@@ -96,6 +112,8 @@ class RegistrationHandler(BaseHandler):
         self.session_lifetime = hs.config.session_lifetime
         self.access_token_lifetime = hs.config.access_token_lifetime
 
+        init_counters_for_auth_provider("")
+
     async def check_username(
         self,
         localpart: str,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ba13196218..401b84aad1 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -36,6 +36,7 @@ from synapse.api.ratelimiting import Ratelimiter
 from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
+from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
 from synapse.types import (
     JsonDict,
     Requester,
@@ -79,7 +80,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.account_data_handler = hs.get_account_data_handler()
         self.event_auth_handler = hs.get_event_auth_handler()
 
-        self.member_linearizer = Linearizer(name="member")
+        self.member_linearizer: Linearizer = Linearizer(name="member")
 
         self.clock = hs.get_clock()
         self.spam_checker = hs.get_spam_checker()
@@ -556,6 +557,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             content.pop("displayname", None)
             content.pop("avatar_url", None)
 
+        if len(content.get("displayname") or "") > MAX_DISPLAYNAME_LEN:
+            raise SynapseError(
+                400,
+                f"Displayname is too long (max {MAX_DISPLAYNAME_LEN})",
+                errcode=Codes.BAD_JSON,
+            )
+
+        if len(content.get("avatar_url") or "") > MAX_AVATAR_URL_LEN:
+            raise SynapseError(
+                400,
+                f"Avatar URL is too long (max {MAX_AVATAR_URL_LEN})",
+                errcode=Codes.BAD_JSON,
+            )
+
         effective_membership_state = action
         if action in ["kick", "unban"]:
             effective_membership_state = "leave"
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index ac6cfc0da9..906985c754 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,12 +28,11 @@ from synapse.api.constants import (
     Membership,
     RoomTypes,
 )
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
 from synapse.events import EventBase
 from synapse.events.utils import format_event_for_client_v2
 from synapse.types import JsonDict
 from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -76,6 +75,9 @@ class _PaginationSession:
 
 
 class RoomSummaryHandler:
+    # A unique key used for pagination sessions for the room hierarchy endpoint.
+    _PAGINATION_SESSION_TYPE = "room_hierarchy_pagination"
+
     # The time a pagination session remains valid for.
     _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
 
@@ -87,12 +89,6 @@ class RoomSummaryHandler:
         self._server_name = hs.hostname
         self._federation_client = hs.get_federation_client()
 
-        # A map of query information to the current pagination state.
-        #
-        # TODO Allow for multiple workers to share this data.
-        # TODO Expire pagination tokens.
-        self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
-
         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.
         self._pagination_response_cache: ResponseCache[
@@ -102,21 +98,6 @@ class RoomSummaryHandler:
             "get_room_hierarchy",
         )
 
-    def _expire_pagination_sessions(self):
-        """Expire pagination session which are old."""
-        expire_before = (
-            self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
-        )
-        to_expire = []
-
-        for key, value in self._pagination_sessions.items():
-            if value.creation_time_ms < expire_before:
-                to_expire.append(key)
-
-        for key in to_expire:
-            logger.debug("Expiring pagination session id %s", key)
-            del self._pagination_sessions[key]
-
     async def get_space_summary(
         self,
         requester: str,
@@ -327,18 +308,29 @@ class RoomSummaryHandler:
 
         # If this is continuing a previous session, pull the persisted data.
         if from_token:
-            self._expire_pagination_sessions()
+            try:
+                pagination_session = await self._store.get_session(
+                    session_type=self._PAGINATION_SESSION_TYPE,
+                    session_id=from_token,
+                )
+            except StoreError:
+                raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
 
-            pagination_key = _PaginationKey(
-                requested_room_id, suggested_only, max_depth, from_token
-            )
-            if pagination_key not in self._pagination_sessions:
+            # If the requester, room ID, suggested-only, or max depth were modified
+            # the session is invalid.
+            if (
+                requester != pagination_session["requester"]
+                or requested_room_id != pagination_session["room_id"]
+                or suggested_only != pagination_session["suggested_only"]
+                or max_depth != pagination_session["max_depth"]
+            ):
                 raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
 
             # Load the previous state.
-            pagination_session = self._pagination_sessions[pagination_key]
-            room_queue = pagination_session.room_queue
-            processed_rooms = pagination_session.processed_rooms
+            room_queue = [
+                _RoomQueueEntry(*fields) for fields in pagination_session["room_queue"]
+            ]
+            processed_rooms = set(pagination_session["processed_rooms"])
         else:
             # The queue of rooms to process, the next room is last on the stack.
             room_queue = [_RoomQueueEntry(requested_room_id, ())]
@@ -456,13 +448,21 @@ class RoomSummaryHandler:
 
         # If there's additional data, generate a pagination token (and persist state).
         if room_queue:
-            next_batch = random_string(24)
-            result["next_batch"] = next_batch
-            pagination_key = _PaginationKey(
-                requested_room_id, suggested_only, max_depth, next_batch
-            )
-            self._pagination_sessions[pagination_key] = _PaginationSession(
-                self._clock.time_msec(), room_queue, processed_rooms
+            result["next_batch"] = await self._store.create_session(
+                session_type=self._PAGINATION_SESSION_TYPE,
+                value={
+                    # Information which must be identical across pagination.
+                    "requester": requester,
+                    "room_id": requested_room_id,
+                    "suggested_only": suggested_only,
+                    "max_depth": max_depth,
+                    # The stored state.
+                    "room_queue": [
+                        attr.astuple(room_entry) for room_entry in room_queue
+                    ],
+                    "processed_rooms": list(processed_rooms),
+                },
+                expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS,
             )
 
         return result
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 1b855a685c..0e6ebb574e 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -37,6 +37,7 @@ from twisted.web.server import Request
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.register import init_counters_for_auth_provider
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html, respond_with_redirect
@@ -213,6 +214,7 @@ class SsoHandler:
         p_id = p.idp_id
         assert p_id not in self._identity_providers
         self._identity_providers[p_id] = p
+        init_counters_for_auth_provider(p_id)
 
     def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
         """Get the configured identity providers"""
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 590642f510..86c3c7f0df 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1,5 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2018, 2019 New Vector Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,6 +30,8 @@ from prometheus_client import Counter
 
 from synapse.api.constants import AccountDataTypes, EventTypes, Membership
 from synapse.api.filtering import FilterCollection
+from synapse.api.presence import UserPresenceState
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
@@ -86,20 +87,20 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
 SyncRequestKey = Tuple[Any, ...]
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class SyncConfig:
-    user = attr.ib(type=UserID)
-    filter_collection = attr.ib(type=FilterCollection)
-    is_guest = attr.ib(type=bool)
-    request_key = attr.ib(type=SyncRequestKey)
-    device_id = attr.ib(type=Optional[str])
+    user: UserID
+    filter_collection: FilterCollection
+    is_guest: bool
+    request_key: SyncRequestKey
+    device_id: Optional[str]
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class TimelineBatch:
-    prev_batch = attr.ib(type=StreamToken)
-    events = attr.ib(type=List[EventBase])
-    limited = attr.ib(type=bool)
+    prev_batch: StreamToken
+    events: List[EventBase]
+    limited: bool
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -113,16 +114,16 @@ class TimelineBatch:
 # if there are updates for it, which we check after the instance has been created.
 # This should not be a big deal because we update the notification counts afterwards as
 # well anyway.
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class JoinedSyncResult:
-    room_id = attr.ib(type=str)
-    timeline = attr.ib(type=TimelineBatch)
-    state = attr.ib(type=StateMap[EventBase])
-    ephemeral = attr.ib(type=List[JsonDict])
-    account_data = attr.ib(type=List[JsonDict])
-    unread_notifications = attr.ib(type=JsonDict)
-    summary = attr.ib(type=Optional[JsonDict])
-    unread_count = attr.ib(type=int)
+    room_id: str
+    timeline: TimelineBatch
+    state: StateMap[EventBase]
+    ephemeral: List[JsonDict]
+    account_data: List[JsonDict]
+    unread_notifications: JsonDict
+    summary: Optional[JsonDict]
+    unread_count: int
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -138,12 +139,12 @@ class JoinedSyncResult:
         )
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class ArchivedSyncResult:
-    room_id = attr.ib(type=str)
-    timeline = attr.ib(type=TimelineBatch)
-    state = attr.ib(type=StateMap[EventBase])
-    account_data = attr.ib(type=List[JsonDict])
+    room_id: str
+    timeline: TimelineBatch
+    state: StateMap[EventBase]
+    account_data: List[JsonDict]
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -152,37 +153,37 @@ class ArchivedSyncResult:
         return bool(self.timeline or self.state or self.account_data)
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class InvitedSyncResult:
-    room_id = attr.ib(type=str)
-    invite = attr.ib(type=EventBase)
+    room_id: str
+    invite: EventBase
 
     def __bool__(self) -> bool:
         """Invited rooms should always be reported to the client"""
         return True
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class KnockedSyncResult:
-    room_id = attr.ib(type=str)
-    knock = attr.ib(type=EventBase)
+    room_id: str
+    knock: EventBase
 
     def __bool__(self) -> bool:
         """Knocked rooms should always be reported to the client"""
         return True
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class GroupsSyncResult:
-    join = attr.ib(type=JsonDict)
-    invite = attr.ib(type=JsonDict)
-    leave = attr.ib(type=JsonDict)
+    join: JsonDict
+    invite: JsonDict
+    leave: JsonDict
 
     def __bool__(self) -> bool:
         return bool(self.join or self.invite or self.leave)
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class DeviceLists:
     """
     Attributes:
@@ -190,27 +191,27 @@ class DeviceLists:
         left: List of user_ids whose devices we no longer track
     """
 
-    changed = attr.ib(type=Collection[str])
-    left = attr.ib(type=Collection[str])
+    changed: Collection[str]
+    left: Collection[str]
 
     def __bool__(self) -> bool:
         return bool(self.changed or self.left)
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _RoomChanges:
     """The set of room entries to include in the sync, plus the set of joined
     and left room IDs since last sync.
     """
 
-    room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
-    invited = attr.ib(type=List[InvitedSyncResult])
-    knocked = attr.ib(type=List[KnockedSyncResult])
-    newly_joined_rooms = attr.ib(type=List[str])
-    newly_left_rooms = attr.ib(type=List[str])
+    room_entries: List["RoomSyncResultBuilder"]
+    invited: List[InvitedSyncResult]
+    knocked: List[KnockedSyncResult]
+    newly_joined_rooms: List[str]
+    newly_left_rooms: List[str]
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class SyncResult:
     """
     Attributes:
@@ -230,18 +231,18 @@ class SyncResult:
         groups: Group updates, if any
     """
 
-    next_batch = attr.ib(type=StreamToken)
-    presence = attr.ib(type=List[JsonDict])
-    account_data = attr.ib(type=List[JsonDict])
-    joined = attr.ib(type=List[JoinedSyncResult])
-    invited = attr.ib(type=List[InvitedSyncResult])
-    knocked = attr.ib(type=List[KnockedSyncResult])
-    archived = attr.ib(type=List[ArchivedSyncResult])
-    to_device = attr.ib(type=List[JsonDict])
-    device_lists = attr.ib(type=DeviceLists)
-    device_one_time_keys_count = attr.ib(type=JsonDict)
-    device_unused_fallback_key_types = attr.ib(type=List[str])
-    groups = attr.ib(type=Optional[GroupsSyncResult])
+    next_batch: StreamToken
+    presence: List[UserPresenceState]
+    account_data: List[JsonDict]
+    joined: List[JoinedSyncResult]
+    invited: List[InvitedSyncResult]
+    knocked: List[KnockedSyncResult]
+    archived: List[ArchivedSyncResult]
+    to_device: List[JsonDict]
+    device_lists: DeviceLists
+    device_one_time_keys_count: JsonDict
+    device_unused_fallback_key_types: List[str]
+    groups: Optional[GroupsSyncResult]
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -701,7 +702,7 @@ class SyncHandler:
         name_id = state_ids.get((EventTypes.Name, ""))
         canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
 
-        summary = {}
+        summary: JsonDict = {}
         empty_ms = MemberSummary([], 0)
 
         # TODO: only send these when they change.
@@ -1843,6 +1844,9 @@ class SyncHandler:
         knocked = []
 
         for event in room_list:
+            if event.room_version_id not in KNOWN_ROOM_VERSIONS:
+                continue
+
             if event.membership == Membership.JOIN:
                 room_entries.append(
                     RoomSyncResultBuilder(
@@ -2076,21 +2080,23 @@ class SyncHandler:
         # If the membership's stream ordering is after the given stream
         # ordering, we need to go and work out if the user was in the room
         # before.
-        for room_id, event_pos in joined_rooms:
-            if not event_pos.persisted_after(room_key):
-                joined_room_ids.add(room_id)
+        for joined_room in joined_rooms:
+            if not joined_room.event_pos.persisted_after(room_key):
+                joined_room_ids.add(joined_room.room_id)
                 continue
 
-            logger.info("User joined room after current token: %s", room_id)
+            logger.info("User joined room after current token: %s", joined_room.room_id)
 
             extrems = (
                 await self.store.get_forward_extremities_for_room_at_stream_ordering(
-                    room_id, event_pos.stream
+                    joined_room.room_id, joined_room.event_pos.stream
                 )
             )
-            users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
+            users_in_room = await self.state.get_current_users_in_room(
+                joined_room.room_id, extrems
+            )
             if user_id in users_in_room:
-                joined_room_ids.add(room_id)
+                joined_room_ids.add(joined_room.room_id)
 
         return frozenset(joined_room_ids)
 
@@ -2160,7 +2166,7 @@ def _calculate_state(
     return {event_id_to_key[e]: e for e in state_ids}
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class SyncResultBuilder:
     """Used to help build up a new SyncResult for a user
 
@@ -2172,33 +2178,33 @@ class SyncResultBuilder:
         joined_room_ids: List of rooms the user is joined to
 
         # The following mirror the fields in a sync response
-        presence (list)
-        account_data (list)
-        joined (list[JoinedSyncResult])
-        invited (list[InvitedSyncResult])
-        knocked (list[KnockedSyncResult])
-        archived (list[ArchivedSyncResult])
-        groups (GroupsSyncResult|None)
-        to_device (list)
+        presence
+        account_data
+        joined
+        invited
+        knocked
+        archived
+        groups
+        to_device
     """
 
-    sync_config = attr.ib(type=SyncConfig)
-    full_state = attr.ib(type=bool)
-    since_token = attr.ib(type=Optional[StreamToken])
-    now_token = attr.ib(type=StreamToken)
-    joined_room_ids = attr.ib(type=FrozenSet[str])
+    sync_config: SyncConfig
+    full_state: bool
+    since_token: Optional[StreamToken]
+    now_token: StreamToken
+    joined_room_ids: FrozenSet[str]
 
-    presence = attr.ib(type=List[JsonDict], default=attr.Factory(list))
-    account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
-    joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
-    invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
-    knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list))
-    archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
-    groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
-    to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
+    presence: List[UserPresenceState] = attr.Factory(list)
+    account_data: List[JsonDict] = attr.Factory(list)
+    joined: List[JoinedSyncResult] = attr.Factory(list)
+    invited: List[InvitedSyncResult] = attr.Factory(list)
+    knocked: List[KnockedSyncResult] = attr.Factory(list)
+    archived: List[ArchivedSyncResult] = attr.Factory(list)
+    groups: Optional[GroupsSyncResult] = None
+    to_device: List[JsonDict] = attr.Factory(list)
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class RoomSyncResultBuilder:
     """Stores information needed to create either a `JoinedSyncResult` or
     `ArchivedSyncResult`.
@@ -2214,10 +2220,10 @@ class RoomSyncResultBuilder:
         upto_token: Latest point to return events from.
     """
 
-    room_id = attr.ib(type=str)
-    rtype = attr.ib(type=str)
-    events = attr.ib(type=Optional[List[EventBase]])
-    newly_joined = attr.ib(type=bool)
-    full_state = attr.ib(type=bool)
-    since_token = attr.ib(type=Optional[StreamToken])
-    upto_token = attr.ib(type=StreamToken)
+    room_id: str
+    rtype: str
+    events: Optional[List[EventBase]]
+    newly_joined: bool
+    full_state: bool
+    since_token: Optional[StreamToken]
+    upto_token: StreamToken
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 4c3b669fae..13b0c61d2e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -34,3 +34,8 @@ class UIAuthSessionDataConstants:
     # used by validate_user_via_ui_auth to store the mxid of the user we are validating
     # for.
     REQUEST_USER_ID = "request_user_id"
+
+    # used during registration to store the registration token used (if required) so that:
+    # - we can prevent a token being used twice by one session
+    # - we can 'use up' the token after registration has successfully completed
+    REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 5414ce77d8..d3828dec6b 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -49,7 +49,7 @@ class UserInteractiveAuthChecker:
             clientip: The IP address of the client.
 
         Raises:
-            SynapseError if authentication failed
+            LoginError if authentication failed.
 
         Returns:
             The result of authentication (to pass back to the client?)
@@ -131,7 +131,9 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
             )
             if resp_body["success"]:
                 return True
-        raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+        raise LoginError(
+            401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED
+        )
 
 
 class _BaseThreepidAuthChecker:
@@ -191,7 +193,9 @@ class _BaseThreepidAuthChecker:
             raise AssertionError("Unrecognized threepid medium: %s" % (medium,))
 
         if not threepid:
-            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+            raise LoginError(
+                401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED
+            )
 
         if threepid["medium"] != medium:
             raise LoginError(
@@ -237,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
         return await self._check_threepid("msisdn", authdict)
 
 
+class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
+    AUTH_TYPE = LoginType.REGISTRATION_TOKEN
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+        self.hs = hs
+        self._enabled = bool(hs.config.registration_requires_token)
+        self.store = hs.get_datastore()
+
+    def is_enabled(self) -> bool:
+        return self._enabled
+
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
+        if "token" not in authdict:
+            raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
+        if not isinstance(authdict["token"], str):
+            raise LoginError(
+                400, "Registration token must be a string", Codes.INVALID_PARAM
+            )
+        if "session" not in authdict:
+            raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
+
+        # Get these here to avoid cyclic dependencies
+        from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+
+        auth_handler = self.hs.get_auth_handler()
+
+        session = authdict["session"]
+        token = authdict["token"]
+
+        # If the LoginType.REGISTRATION_TOKEN stage has already been completed,
+        # return early to avoid incrementing `pending` again.
+        stored_token = await auth_handler.get_session_data(
+            session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
+        )
+        if stored_token:
+            if token != stored_token:
+                raise LoginError(
+                    400, "Registration token has changed", Codes.INVALID_PARAM
+                )
+            else:
+                return token
+
+        if await self.store.registration_token_is_valid(token):
+            # Increment pending counter, so that if token has limited uses it
+            # can't be used up by someone else in the meantime.
+            await self.store.set_registration_token_pending(token)
+            # Store the token in the UIA session, so that once registration
+            # is complete `completed` can be incremented.
+            await auth_handler.set_session_data(
+                session,
+                UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+                token,
+            )
+            # The token will be stored as the result of the authentication stage
+            # in ui_auth_sessions_credentials. This allows the pending counter
+            # for tokens to be decremented when expired sessions are deleted.
+            return token
+        else:
+            raise LoginError(
+                401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
+            )
+
+
 INTERACTIVE_AUTH_CHECKERS = [
     DummyAuthChecker,
     TermsAuthChecker,
     RecaptchaAuthChecker,
     EmailIdentityAuthChecker,
     MsisdnAuthChecker,
+    RegistrationTokenAuthChecker,
 ]
 """A list of UserInteractiveAuthChecker classes"""
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 55ea97a07f..9a2684aca4 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -12,8 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
 from synapse.http.server import DirectServeJsonResource
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class AdditionalResource(DirectServeJsonResource):
     """Resource wrapper for additional_resources
@@ -25,7 +32,7 @@ class AdditionalResource(DirectServeJsonResource):
     and exception handling.
     """
 
-    def __init__(self, hs, handler):
+    def __init__(self, hs: "HomeServer", handler):
         """Initialise AdditionalResource
 
         The ``handler`` should return a deferred which completes when it has
@@ -33,14 +40,14 @@ class AdditionalResource(DirectServeJsonResource):
         ``request.write()``, and call ``request.finish()``.
 
         Args:
-            hs (synapse.server.HomeServer): homeserver
+            hs: homeserver
             handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
                 function to be called to handle the request.
         """
         super().__init__()
         self._handler = handler
 
-    def _async_render(self, request):
+    def _async_render(self, request: Request):
         # Cheekily pass the result straight through, so we don't need to worry
         # if its an awaitable or not.
         return self._handler(request)
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index b8ed4ec905..f68646fd0d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import time
-from typing import List
+from typing import Callable, Dict, List
 
 import attr
 
@@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable
 
 logger = logging.getLogger(__name__)
 
-SERVER_CACHE = {}
+SERVER_CACHE: Dict[bytes, List["Server"]] = {}
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
 class Server:
     """
     Our record of an individual server which can be tried to reach a destination.
 
     Attributes:
-        host (bytes): target hostname
-        port (int):
-        priority (int):
-        weight (int):
-        expires (int): when the cache should expire this record - in *seconds* since
+        host: target hostname
+        port:
+        priority:
+        weight:
+        expires: when the cache should expire this record - in *seconds* since
             the epoch
     """
 
-    host = attr.ib()
-    port = attr.ib()
-    priority = attr.ib(default=0)
-    weight = attr.ib(default=0)
-    expires = attr.ib(default=0)
+    host: bytes
+    port: int
+    priority: int = 0
+    weight: int = 0
+    expires: int = 0
 
 
-def _sort_server_list(server_list):
+def _sort_server_list(server_list: List[Server]) -> List[Server]:
     """Given a list of SRV records sort them into priority order and shuffle
     each priority with the given weight.
     """
-    priority_map = {}
+    priority_map: Dict[int, List[Server]] = {}
 
     for server in server_list:
         priority_map.setdefault(server.priority, []).append(server)
@@ -103,11 +103,16 @@ class SrvResolver:
 
     Args:
         dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl
-        cache (dict): cache object
-        get_time (callable): clock implementation. Should return seconds since the epoch
+        cache: cache object
+        get_time: clock implementation. Should return seconds since the epoch
     """
 
-    def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
+    def __init__(
+        self,
+        dns_client=client,
+        cache: Dict[bytes, List[Server]] = SERVER_CACHE,
+        get_time: Callable[[], float] = time.time,
+    ):
         self._dns_client = dns_client
         self._cache = cache
         self._get_time = get_time
@@ -116,7 +121,7 @@ class SrvResolver:
         """Look up a SRV record
 
         Args:
-            service_name (bytes): record to look up
+            service_name: record to look up
 
         Returns:
             a list of the SRV records, or an empty list if none found
@@ -158,7 +163,7 @@ class SrvResolver:
             and answers[0].payload
             and answers[0].payload.target == dns.Name(b".")
         ):
-            raise ConnectError("Service %s unavailable" % service_name)
+            raise ConnectError(f"Service {service_name!r} unavailable")
 
         servers = []
 
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index a3f31452d0..6fd88bde20 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -173,7 +173,7 @@ class ProxyAgent(_AgentBase):
             raise ValueError(f"Invalid URI {uri!r}")
 
         parsed_uri = URI.fromBytes(uri)
-        pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+        pool_key = f"{parsed_uri.scheme!r}{parsed_uri.host!r}{parsed_uri.port}"
         request_path = parsed_uri.originForm
 
         should_skip_proxy = False
@@ -199,7 +199,7 @@ class ProxyAgent(_AgentBase):
                 )
             # Cache *all* connections under the same key, since we are only
             # connecting to a single destination, the proxy:
-            pool_key = ("http-proxy", self.http_proxy_endpoint)
+            pool_key = "http-proxy"
             endpoint = self.http_proxy_endpoint
             request_path = uri
         elif (
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2d2ed229e2..b11fa6393b 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -32,6 +32,7 @@ from twisted.internet import defer
 from twisted.web.resource import IResource
 
 from synapse.events import EventBase
+from synapse.events.presence_router import PresenceRouter
 from synapse.http.client import SimpleHttpClient
 from synapse.http.server import (
     DirectServeHtmlResource,
@@ -57,6 +58,8 @@ This package defines the 'stable' API which can be used by extension modules whi
 are loaded into Synapse.
 """
 
+PRESENCE_ALL_USERS = PresenceRouter.ALL_USERS
+
 __all__ = [
     "errors",
     "make_deferred_yieldable",
@@ -70,6 +73,7 @@ __all__ = [
     "DirectServeHtmlResource",
     "DirectServeJsonResource",
     "ModuleApi",
+    "PRESENCE_ALL_USERS",
 ]
 
 logger = logging.getLogger(__name__)
@@ -112,6 +116,7 @@ class ModuleApi:
         self._spam_checker = hs.get_spam_checker()
         self._account_validity_handler = hs.get_account_validity_handler()
         self._third_party_event_rules = hs.get_third_party_event_rules()
+        self._presence_router = hs.get_presence_router()
 
     #################################################################################
     # The following methods should only be called during the module's initialisation.
@@ -131,6 +136,11 @@ class ModuleApi:
         """Registers callbacks for third party event rules capabilities."""
         return self._third_party_event_rules.register_third_party_rules_callbacks
 
+    @property
+    def register_presence_router_callbacks(self):
+        """Registers callbacks for presence router capabilities."""
+        return self._presence_router.register_presence_router_callbacks
+
     def register_web_resource(self, path: str, resource: IResource):
         """Registers a web resource to be served at the given path.
 
diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html
index 63944dc608..b3db06ef97 100644
--- a/synapse/res/templates/recaptcha.html
+++ b/synapse/res/templates/recaptcha.html
@@ -16,6 +16,9 @@ function captchaDone() {
 <body>
 <form id="registrationForm" method="post" action="{{ myurl }}">
     <div>
+        {% if error is defined %}
+            <p class="error"><strong>Error: {{ error }}</strong></p>
+        {% endif %}
         <p>
         Hello! We need to prevent computer programs and other automated
         things from creating accounts on this server.
diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html
new file mode 100644
index 0000000000..4577ce1702
--- /dev/null
+++ b/synapse/res/templates/registration_token.html
@@ -0,0 +1,23 @@
+<html>
+<head>
+<title>Authentication</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+</head>
+<body>
+<form id="registrationForm" method="post" action="{{ myurl }}">
+    <div>
+        {% if error is defined %}
+            <p class="error"><strong>Error: {{ error }}</strong></p>
+        {% endif %}
+        <p>
+           Please enter a registration token.
+        </p>
+        <input type="hidden" name="session" value="{{ session }}" />
+        <input type="text" name="token" />
+        <input type="submit" value="Authenticate" />
+    </div>
+</form>
+</body>
+</html>
diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html
index dfef9897ee..369ff446d2 100644
--- a/synapse/res/templates/terms.html
+++ b/synapse/res/templates/terms.html
@@ -8,6 +8,9 @@
 <body>
 <form id="registrationForm" method="post" action="{{ myurl }}">
     <div>
+        {% if error is defined %}
+            <p class="error"><strong>Error: {{ error }}</strong></p>
+        {% endif %}
         <p>
             Please click the button below if you agree to the
             <a href="{{ terms_url }}">privacy policy of this homeserver.</a>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index d5862a4da4..6e1c8736e1 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -36,7 +36,11 @@ from synapse.rest.admin.event_reports import (
 )
 from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
 from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
-from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
+from synapse.rest.admin.registration_tokens import (
+    ListRegistrationTokensRestServlet,
+    NewRegistrationTokenRestServlet,
+    RegistrationTokenRestServlet,
+)
 from synapse.rest.admin.rooms import (
     DeleteRoomRestServlet,
     ForwardExtremitiesRestServlet,
@@ -47,7 +51,6 @@ from synapse.rest.admin.rooms import (
     RoomMembersRestServlet,
     RoomRestServlet,
     RoomStateRestServlet,
-    ShutdownRoomRestServlet,
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
 from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
@@ -220,7 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RoomMembersRestServlet(hs).register(http_server)
     DeleteRoomRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
-    PurgeRoomServlet(hs).register(http_server)
     SendServerNoticeServlet(hs).register(http_server)
     VersionServlet(hs).register(http_server)
     UserAdminServlet(hs).register(http_server)
@@ -241,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RoomEventContextServlet(hs).register(http_server)
     RateLimitRestServlet(hs).register(http_server)
     UsernameAvailableRestServlet(hs).register(http_server)
+    ListRegistrationTokensRestServlet(hs).register(http_server)
+    NewRegistrationTokenRestServlet(hs).register(http_server)
+    RegistrationTokenRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(
@@ -253,7 +258,6 @@ def register_servlets_for_client_rest_resource(
     PurgeHistoryRestServlet(hs).register(http_server)
     ResetPasswordRestServlet(hs).register(http_server)
     SearchUsersRestServlet(hs).register(http_server)
-    ShutdownRoomRestServlet(hs).register(http_server)
     UserRegisterServlet(hs).register(http_server)
     DeleteGroupAdminRestServlet(hs).register(http_server)
     AccountValidityRenewServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
deleted file mode 100644
index 2365ff7a0f..0000000000
--- a/synapse/rest/admin/purge_room_servlet.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import TYPE_CHECKING, Tuple
-
-from synapse.http.servlet import (
-    RestServlet,
-    assert_params_in_dict,
-    parse_json_object_from_request,
-)
-from synapse.http.site import SynapseRequest
-from synapse.rest.admin import assert_requester_is_admin
-from synapse.rest.admin._base import admin_patterns
-from synapse.types import JsonDict
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-
-class PurgeRoomServlet(RestServlet):
-    """Servlet which will remove all trace of a room from the database
-
-    POST /_synapse/admin/v1/purge_room
-    {
-        "room_id": "!room:id"
-    }
-
-    returns:
-
-    {}
-    """
-
-    PATTERNS = admin_patterns("/purge_room$")
-
-    def __init__(self, hs: "HomeServer"):
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.pagination_handler = hs.get_pagination_handler()
-
-    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        await assert_requester_is_admin(self.auth, request)
-
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ("room_id",))
-
-        await self.pagination_handler.purge_room(body["room_id"])
-
-        return 200, {}
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
new file mode 100644
index 0000000000..5a1c929d85
--- /dev/null
+++ b/synapse/rest/admin/registration_tokens.py
@@ -0,0 +1,321 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import string
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+    RestServlet,
+    parse_boolean,
+    parse_json_object_from_request,
+)
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ListRegistrationTokensRestServlet(RestServlet):
+    """List registration tokens.
+
+    To list all tokens:
+
+        GET /_synapse/admin/v1/registration_tokens
+
+        200 OK
+
+        {
+            "registration_tokens": [
+                {
+                    "token": "abcd",
+                    "uses_allowed": 3,
+                    "pending": 0,
+                    "completed": 1,
+                    "expiry_time": null
+                },
+                {
+                    "token": "wxyz",
+                    "uses_allowed": null,
+                    "pending": 0,
+                    "completed": 9,
+                    "expiry_time": 1625394937000
+                }
+            ]
+        }
+
+    The optional query parameter `valid` can be used to filter the response.
+    If it is `true`, only valid tokens are returned. If it is `false`, only
+    tokens that have expired or have had all uses exhausted are returned.
+    If it is omitted, all tokens are returned regardless of validity.
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+        valid = parse_boolean(request, "valid")
+        token_list = await self.store.get_registration_tokens(valid)
+        return 200, {"registration_tokens": token_list}
+
+
+class NewRegistrationTokenRestServlet(RestServlet):
+    """Create a new registration token.
+
+    For example, to create a token specifying some fields:
+
+        POST /_synapse/admin/v1/registration_tokens/new
+
+        {
+            "token": "defg",
+            "uses_allowed": 1
+        }
+
+        200 OK
+
+        {
+            "token": "defg",
+            "uses_allowed": 1,
+            "pending": 0,
+            "completed": 0,
+            "expiry_time": null
+        }
+
+    Defaults are used for any fields not specified.
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens/new$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        # A string of all the characters allowed to be in a registration_token
+        self.allowed_chars = string.ascii_letters + string.digits + "-_"
+        self.allowed_chars_set = set(self.allowed_chars)
+
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self.auth, request)
+        body = parse_json_object_from_request(request)
+
+        if "token" in body:
+            token = body["token"]
+            if not isinstance(token, str):
+                raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM)
+            if not (0 < len(token) <= 64):
+                raise SynapseError(
+                    400,
+                    "token must not be empty and must not be longer than 64 characters",
+                    Codes.INVALID_PARAM,
+                )
+            if not set(token).issubset(self.allowed_chars_set):
+                raise SynapseError(
+                    400,
+                    "token must consist only of characters matched by the regex [A-Za-z0-9-_]",
+                    Codes.INVALID_PARAM,
+                )
+
+        else:
+            # Get length of token to generate (default is 16)
+            length = body.get("length", 16)
+            if not isinstance(length, int):
+                raise SynapseError(
+                    400, "length must be an integer", Codes.INVALID_PARAM
+                )
+            if not (0 < length <= 64):
+                raise SynapseError(
+                    400,
+                    "length must be greater than zero and not greater than 64",
+                    Codes.INVALID_PARAM,
+                )
+
+            # Generate token
+            token = await self.store.generate_registration_token(
+                length, self.allowed_chars
+            )
+
+        uses_allowed = body.get("uses_allowed", None)
+        if not (
+            uses_allowed is None
+            or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+        ):
+            raise SynapseError(
+                400,
+                "uses_allowed must be a non-negative integer or null",
+                Codes.INVALID_PARAM,
+            )
+
+        expiry_time = body.get("expiry_time", None)
+        if not isinstance(expiry_time, (int, type(None))):
+            raise SynapseError(
+                400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+            )
+        if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+            raise SynapseError(
+                400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+            )
+
+        created = await self.store.create_registration_token(
+            token, uses_allowed, expiry_time
+        )
+        if not created:
+            raise SynapseError(
+                400, f"Token already exists: {token}", Codes.INVALID_PARAM
+            )
+
+        resp = {
+            "token": token,
+            "uses_allowed": uses_allowed,
+            "pending": 0,
+            "completed": 0,
+            "expiry_time": expiry_time,
+        }
+        return 200, resp
+
+
+class RegistrationTokenRestServlet(RestServlet):
+    """Retrieve, update, or delete the given token.
+
+    For example,
+
+    to retrieve a token:
+
+        GET /_synapse/admin/v1/registration_tokens/abcd
+
+        200 OK
+
+        {
+            "token": "abcd",
+            "uses_allowed": 3,
+            "pending": 0,
+            "completed": 1,
+            "expiry_time": null
+        }
+
+
+    to update a token:
+
+        PUT /_synapse/admin/v1/registration_tokens/defg
+
+        {
+            "uses_allowed": 5,
+            "expiry_time": 4781243146000
+        }
+
+        200 OK
+
+        {
+            "token": "defg",
+            "uses_allowed": 5,
+            "pending": 0,
+            "completed": 0,
+            "expiry_time": 4781243146000
+        }
+
+
+    to delete a token:
+
+        DELETE /_synapse/admin/v1/registration_tokens/wxyz
+
+        200 OK
+
+        {}
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+        """Retrieve a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+        token_info = await self.store.get_one_registration_token(token)
+
+        # If no result return a 404
+        if token_info is None:
+            raise NotFoundError(f"No such registration token: {token}")
+
+        return 200, token_info
+
+    async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+        """Update a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+        body = parse_json_object_from_request(request)
+        new_attributes = {}
+
+        # Only add uses_allowed to new_attributes if it is present and valid
+        if "uses_allowed" in body:
+            uses_allowed = body["uses_allowed"]
+            if not (
+                uses_allowed is None
+                or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+            ):
+                raise SynapseError(
+                    400,
+                    "uses_allowed must be a non-negative integer or null",
+                    Codes.INVALID_PARAM,
+                )
+            new_attributes["uses_allowed"] = uses_allowed
+
+        if "expiry_time" in body:
+            expiry_time = body["expiry_time"]
+            if not isinstance(expiry_time, (int, type(None))):
+                raise SynapseError(
+                    400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+                )
+            if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+                raise SynapseError(
+                    400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+                )
+            new_attributes["expiry_time"] = expiry_time
+
+        if len(new_attributes) == 0:
+            # Nothing to update, get token info to return
+            token_info = await self.store.get_one_registration_token(token)
+        else:
+            token_info = await self.store.update_registration_token(
+                token, new_attributes
+            )
+
+        # If no result return a 404
+        if token_info is None:
+            raise NotFoundError(f"No such registration token: {token}")
+
+        return 200, token_info
+
+    async def on_DELETE(
+        self, request: SynapseRequest, token: str
+    ) -> Tuple[int, JsonDict]:
+        """Delete a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+
+        if await self.store.delete_registration_token(token):
+            return 200, {}
+
+        raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 975c28b225..ad83d4b54c 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -46,41 +46,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class ShutdownRoomRestServlet(RestServlet):
-    """Shuts down a room by removing all local users from the room and blocking
-    all future invites and joins to the room. Any local aliases will be repointed
-    to a new room created by `new_room_user_id` and kicked users will be auto
-    joined to the new room.
-    """
-
-    PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
-
-    def __init__(self, hs: "HomeServer"):
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.room_shutdown_handler = hs.get_room_shutdown_handler()
-
-    async def on_POST(
-        self, request: SynapseRequest, room_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
-
-        content = parse_json_object_from_request(request)
-        assert_params_in_dict(content, ["new_room_user_id"])
-
-        ret = await self.room_shutdown_handler.shutdown_room(
-            room_id=room_id,
-            new_room_user_id=content["new_room_user_id"],
-            new_room_name=content.get("room_name"),
-            message=content.get("message"),
-            requester_user_id=requester.user.to_string(),
-            block=True,
-        )
-
-        return (200, ret)
-
-
 class DeleteRoomRestServlet(RestServlet):
     """Delete a room from server.
 
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 3c8a0c6883..c1a1ba645e 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
         if not isinstance(deactivate, bool):
             raise SynapseError(400, "'deactivated' parameter is not of type boolean")
 
-        # convert into List[Tuple[str, str]]
+        # convert List[Dict[str, str]] into Set[Tuple[str, str]]
         if external_ids is not None:
-            new_external_ids = []
-            for external_id in external_ids:
-                new_external_ids.append(
-                    (external_id["auth_provider"], external_id["external_id"])
-                )
+            new_external_ids = {
+                (external_id["auth_provider"], external_id["external_id"])
+                for external_id in external_ids
+            }
+
+        # convert List[Dict[str, str]] into Set[Tuple[str, str]]
+        if threepids is not None:
+            new_threepids = {
+                (threepid["medium"], threepid["address"]) for threepid in threepids
+            }
 
         if user:  # modify user
             if "displayname" in body:
@@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
                 )
 
             if threepids is not None:
-                # remove old threepids from user
-                old_threepids = await self.store.user_get_threepids(user_id)
-                for threepid in old_threepids:
+                # get changed threepids (added and removed)
+                # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
+                cur_threepids = {
+                    (threepid["medium"], threepid["address"])
+                    for threepid in await self.store.user_get_threepids(user_id)
+                }
+                add_threepids = new_threepids - cur_threepids
+                del_threepids = cur_threepids - new_threepids
+
+                # remove old threepids
+                for medium, address in del_threepids:
                     try:
                         await self.auth_handler.delete_threepid(
-                            user_id, threepid["medium"], threepid["address"], None
+                            user_id, medium, address, None
                         )
                     except Exception:
                         logger.exception("Failed to remove threepids")
                         raise SynapseError(500, "Failed to remove threepids")
 
-                # add new threepids to user
+                # add new threepids
                 current_time = self.hs.get_clock().time_msec()
-                for threepid in threepids:
+                for medium, address in add_threepids:
                     await self.auth_handler.add_threepid(
-                        user_id, threepid["medium"], threepid["address"], current_time
+                        user_id, medium, address, current_time
                     )
 
             if external_ids is not None:
                 # get changed external_ids (added and removed)
-                cur_external_ids = await self.store.get_external_ids_by_user(user_id)
-                add_external_ids = set(new_external_ids) - set(cur_external_ids)
-                del_external_ids = set(cur_external_ids) - set(new_external_ids)
+                cur_external_ids = set(
+                    await self.store.get_external_ids_by_user(user_id)
+                )
+                add_external_ids = new_external_ids - cur_external_ids
+                del_external_ids = cur_external_ids - new_external_ids
 
                 # remove old external_ids
                 for auth_provider, external_id in del_external_ids:
@@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
 
             if threepids is not None:
                 current_time = self.hs.get_clock().time_msec()
-                for threepid in threepids:
+                for medium, address in new_threepids:
                     await self.auth_handler.add_threepid(
-                        user_id, threepid["medium"], threepid["address"], current_time
+                        user_id, medium, address, current_time
                     )
                     if (
                         self.hs.config.email_enable_notifs
@@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
                             kind="email",
                             app_id="m.email",
                             app_display_name="Email Notifications",
-                            device_display_name=threepid["address"],
-                            pushkey=threepid["address"],
+                            device_display_name=address,
+                            pushkey=address,
                             lang=None,  # We don't know a user's language here
                             data={},
                         )
diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py
index 3ebe401861..6c24b96c54 100644
--- a/synapse/rest/client/account_validity.py
+++ b/synapse/rest/client/account_validity.py
@@ -13,24 +13,27 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.errors import SynapseError
-from synapse.http.server import respond_with_html
-from synapse.http.servlet import RestServlet
+from twisted.web.server import Request
+
+from synapse.http.server import HttpServer, respond_with_html
+from synapse.http.servlet import RestServlet, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class AccountValidityRenewServlet(RestServlet):
     PATTERNS = client_patterns("/account_validity/renew$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
         self.hs = hs
@@ -46,18 +49,14 @@ class AccountValidityRenewServlet(RestServlet):
             hs.config.account_validity.account_validity_invalid_token_template
         )
 
-    async def on_GET(self, request):
-        if b"token" not in request.args:
-            raise SynapseError(400, "Missing renewal token")
-        renewal_token = request.args[b"token"][0]
+    async def on_GET(self, request: Request) -> None:
+        renewal_token = parse_string(request, "token", required=True)
 
         (
             token_valid,
             token_stale,
             expiration_ts,
-        ) = await self.account_activity_handler.renew_account(
-            renewal_token.decode("utf8")
-        )
+        ) = await self.account_activity_handler.renew_account(renewal_token)
 
         if token_valid:
             status_code = 200
@@ -77,11 +76,7 @@ class AccountValidityRenewServlet(RestServlet):
 class AccountValiditySendMailServlet(RestServlet):
     PATTERNS = client_patterns("/account_validity/send_mail$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
         self.hs = hs
@@ -91,7 +86,7 @@ class AccountValiditySendMailServlet(RestServlet):
             hs.config.account_validity.account_validity_renew_by_email_enabled
         )
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_expired=True)
         user_id = requester.user.to_string()
         await self.account_activity_handler.send_renewal_email_to_user(user_id)
@@ -99,6 +94,6 @@ class AccountValiditySendMailServlet(RestServlet):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     AccountValidityRenewServlet(hs).register(http_server)
     AccountValiditySendMailServlet(hs).register(http_server)
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 6ea1b50a62..91800c0278 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -16,7 +16,7 @@ import logging
 from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
-from synapse.api.errors import SynapseError
+from synapse.api.errors import LoginError, SynapseError
 from synapse.api.urls import CLIENT_API_PREFIX
 from synapse.http.server import respond_with_html
 from synapse.http.servlet import RestServlet, parse_string
@@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
+        self.registration_token_template = hs.config.registration_token_template
         self.success_template = hs.config.fallback_success_template
 
     async def on_GET(self, request, stagetype):
@@ -74,6 +75,12 @@ class AuthRestServlet(RestServlet):
             # re-authenticate with their SSO provider.
             html = await self.auth_handler.start_sso_ui_auth(request, session)
 
+        elif stagetype == LoginType.REGISTRATION_TOKEN:
+            html = self.registration_token_template.render(
+                session=session,
+                myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+            )
+
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
@@ -95,29 +102,32 @@ class AuthRestServlet(RestServlet):
 
             authdict = {"response": response, "session": session}
 
-            success = await self.auth_handler.add_oob_auth(
-                LoginType.RECAPTCHA, authdict, request.getClientIP()
-            )
-
-            if success:
-                html = self.success_template.render()
-            else:
+            try:
+                await self.auth_handler.add_oob_auth(
+                    LoginType.RECAPTCHA, authdict, request.getClientIP()
+                )
+            except LoginError as e:
+                # Authentication failed, let user try again
                 html = self.recaptcha_template.render(
                     session=session,
                     myurl="%s/r0/auth/%s/fallback/web"
                     % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
                     sitekey=self.hs.config.recaptcha_public_key,
+                    error=e.msg,
                 )
+            else:
+                # No LoginError was raised, so authentication was successful
+                html = self.success_template.render()
+
         elif stagetype == LoginType.TERMS:
             authdict = {"session": session}
 
-            success = await self.auth_handler.add_oob_auth(
-                LoginType.TERMS, authdict, request.getClientIP()
-            )
-
-            if success:
-                html = self.success_template.render()
-            else:
+            try:
+                await self.auth_handler.add_oob_auth(
+                    LoginType.TERMS, authdict, request.getClientIP()
+                )
+            except LoginError as e:
+                # Authentication failed, let user try again
                 html = self.terms_template.render(
                     session=session,
                     terms_url="%s_matrix/consent?v=%s"
@@ -127,10 +137,33 @@ class AuthRestServlet(RestServlet):
                     ),
                     myurl="%s/r0/auth/%s/fallback/web"
                     % (CLIENT_API_PREFIX, LoginType.TERMS),
+                    error=e.msg,
                 )
+            else:
+                # No LoginError was raised, so authentication was successful
+                html = self.success_template.render()
+
         elif stagetype == LoginType.SSO:
             # The SSO fallback workflow should not post here,
             raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
+
+        elif stagetype == LoginType.REGISTRATION_TOKEN:
+            token = parse_string(request, "token", required=True)
+            authdict = {"session": session, "token": token}
+
+            try:
+                await self.auth_handler.add_oob_auth(
+                    LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
+                )
+            except LoginError as e:
+                html = self.registration_token_template.render(
+                    session=session,
+                    myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web",
+                    error=e.msg,
+                )
+            else:
+                html = self.success_template.render()
+
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 88e3aac797..65b3b5ce2c 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -15,6 +15,7 @@ import logging
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict
@@ -61,8 +62,19 @@ class CapabilitiesRestServlet(RestServlet):
                 "org.matrix.msc3244.room_capabilities"
             ] = MSC3244_CAPABILITIES
 
+        if self.config.experimental.msc3283_enabled:
+            response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
+                "enabled": self.config.enable_set_displayname
+            }
+            response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
+                "enabled": self.config.enable_set_avatar_url
+            }
+            response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
+                "enabled": self.config.enable_3pid_changes
+            }
+
         return 200, response
 
 
-def register_servlets(hs: "HomeServer", http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     CapabilitiesRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index ffa075c8e5..ee247e3d1e 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
 
 from synapse.api.errors import (
     AuthError,
@@ -22,14 +24,19 @@ from synapse.api.errors import (
     NotFoundError,
     SynapseError,
 )
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ClientDirectoryServer(hs).register(http_server)
     ClientDirectoryListServer(hs).register(http_server)
     ClientAppserviceDirectoryListServer(hs).register(http_server)
@@ -38,21 +45,23 @@ def register_servlets(hs, http_server):
 class ClientDirectoryServer(RestServlet):
     PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.store = hs.get_datastore()
         self.directory_handler = hs.get_directory_handler()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_alias):
-        room_alias = RoomAlias.from_string(room_alias)
+    async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]:
+        room_alias_obj = RoomAlias.from_string(room_alias)
 
-        res = await self.directory_handler.get_association(room_alias)
+        res = await self.directory_handler.get_association(room_alias_obj)
 
         return 200, res
 
-    async def on_PUT(self, request, room_alias):
-        room_alias = RoomAlias.from_string(room_alias)
+    async def on_PUT(
+        self, request: SynapseRequest, room_alias: str
+    ) -> Tuple[int, JsonDict]:
+        room_alias_obj = RoomAlias.from_string(room_alias)
 
         content = parse_json_object_from_request(request)
         if "room_id" not in content:
@@ -61,7 +70,7 @@ class ClientDirectoryServer(RestServlet):
             )
 
         logger.debug("Got content: %s", content)
-        logger.debug("Got room name: %s", room_alias.to_string())
+        logger.debug("Got room name: %s", room_alias_obj.to_string())
 
         room_id = content["room_id"]
         servers = content["servers"] if "servers" in content else None
@@ -78,22 +87,25 @@ class ClientDirectoryServer(RestServlet):
         requester = await self.auth.get_user_by_req(request)
 
         await self.directory_handler.create_association(
-            requester, room_alias, room_id, servers
+            requester, room_alias_obj, room_id, servers
         )
 
         return 200, {}
 
-    async def on_DELETE(self, request, room_alias):
+    async def on_DELETE(
+        self, request: SynapseRequest, room_alias: str
+    ) -> Tuple[int, JsonDict]:
+        room_alias_obj = RoomAlias.from_string(room_alias)
+
         try:
             service = self.auth.get_appservice_by_req(request)
-            room_alias = RoomAlias.from_string(room_alias)
             await self.directory_handler.delete_appservice_association(
-                service, room_alias
+                service, room_alias_obj
             )
             logger.info(
                 "Application service at %s deleted alias %s",
                 service.url,
-                room_alias.to_string(),
+                room_alias_obj.to_string(),
             )
             return 200, {}
         except InvalidClientCredentialsError:
@@ -103,12 +115,10 @@ class ClientDirectoryServer(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         user = requester.user
 
-        room_alias = RoomAlias.from_string(room_alias)
-
-        await self.directory_handler.delete_association(requester, room_alias)
+        await self.directory_handler.delete_association(requester, room_alias_obj)
 
         logger.info(
-            "User %s deleted alias %s", user.to_string(), room_alias.to_string()
+            "User %s deleted alias %s", user.to_string(), room_alias_obj.to_string()
         )
 
         return 200, {}
@@ -117,20 +127,22 @@ class ClientDirectoryServer(RestServlet):
 class ClientDirectoryListServer(RestServlet):
     PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.store = hs.get_datastore()
         self.directory_handler = hs.get_directory_handler()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         room = await self.store.get_room(room_id)
         if room is None:
             raise NotFoundError("Unknown room")
 
         return 200, {"visibility": "public" if room["is_public"] else "private"}
 
-    async def on_PUT(self, request, room_id):
+    async def on_PUT(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         content = parse_json_object_from_request(request)
@@ -142,7 +154,9 @@ class ClientDirectoryListServer(RestServlet):
 
         return 200, {}
 
-    async def on_DELETE(self, request, room_id):
+    async def on_DELETE(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         await self.directory_handler.edit_published_room_list(
@@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet):
         "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.store = hs.get_datastore()
         self.directory_handler = hs.get_directory_handler()
         self.auth = hs.get_auth()
 
-    def on_PUT(self, request, network_id, room_id):
+    async def on_PUT(
+        self, request: SynapseRequest, network_id: str, room_id: str
+    ) -> Tuple[int, JsonDict]:
         content = parse_json_object_from_request(request)
         visibility = content.get("visibility", "public")
-        return self._edit(request, network_id, room_id, visibility)
+        return await self._edit(request, network_id, room_id, visibility)
 
-    def on_DELETE(self, request, network_id, room_id):
-        return self._edit(request, network_id, room_id, "private")
+    async def on_DELETE(
+        self, request: SynapseRequest, network_id: str, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        return await self._edit(request, network_id, room_id, "private")
 
-    async def _edit(self, request, network_id, room_id, visibility):
+    async def _edit(
+        self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if not requester.app_service:
             raise AuthError(
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index d0d9d30d40..012491f597 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,8 +15,9 @@
 # limitations under the License.
 
 import logging
+from typing import Any
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import InvalidAPICallError, SynapseError
 from synapse.http.servlet import (
     RestServlet,
     parse_integer,
@@ -163,6 +164,19 @@ class KeyQueryServlet(RestServlet):
         device_id = requester.device_id
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
+
+        device_keys = body.get("device_keys")
+        if not isinstance(device_keys, dict):
+            raise InvalidAPICallError("'device_keys' must be a JSON object")
+
+        def is_list_of_strings(values: Any) -> bool:
+            return isinstance(values, list) and all(isinstance(v, str) for v in values)
+
+        if any(not is_list_of_strings(keys) for keys in device_keys.values()):
+            raise InvalidAPICallError(
+                "'device_keys' values must be a list of strings",
+            )
+
         result = await self.e2e_keys_handler.query_devices(
             body, timeout, user_id, device_id
         )
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 0c8d8967b7..11d07776b2 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -104,6 +104,12 @@ class LoginRestServlet(RestServlet):
             burst_count=self.hs.config.rc_login_account.burst_count,
         )
 
+        # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
+        # The reason for this is to ensure that the auth_provider_ids are registered
+        # with SsoHandler, which in turn ensures that the login/registration prometheus
+        # counters are initialised for the auth_provider_ids.
+        _load_sso_handlers(hs)
+
     def on_GET(self, request: SynapseRequest):
         flows = []
         if self.jwt_enabled:
@@ -499,12 +505,7 @@ class SsoRedirectServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         # make sure that the relevant handlers are instantiated, so that they
         # register themselves with the main SSOHandler.
-        if hs.config.cas_enabled:
-            hs.get_cas_handler()
-        if hs.config.saml2_enabled:
-            hs.get_saml_handler()
-        if hs.config.oidc_enabled:
-            hs.get_oidc_handler()
+        _load_sso_handlers(hs)
         self._sso_handler = hs.get_sso_handler()
         self._msc2858_enabled = hs.config.experimental.msc2858_enabled
         self._public_baseurl = hs.config.public_baseurl
@@ -598,3 +599,19 @@ def register_servlets(hs, http_server):
     SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
         CasTicketServlet(hs).register(http_server)
+
+
+def _load_sso_handlers(hs: "HomeServer"):
+    """Ensure that the SSO handlers are loaded, if they are enabled by configuration.
+
+    This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
+    with the main SsoHandler.
+
+    It's safe to call this multiple times.
+    """
+    if hs.config.cas.cas_enabled:
+        hs.get_cas_handler()
+    if hs.config.saml2.saml2_enabled:
+        hs.get_saml_handler()
+    if hs.config.oidc.oidc_enabled:
+        hs.get_oidc_handler()
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 84619c5e41..98604a9388 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -13,17 +13,23 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.http.server import respond_with_html_bytes
+from synapse.http.server import HttpServer, respond_with_html_bytes
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.push import PusherConfigException
 from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -31,12 +37,12 @@ logger = logging.getLogger(__name__)
 class PushersRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         user = requester.user
 
@@ -50,14 +56,14 @@ class PushersRestServlet(RestServlet):
 class PushersSetRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers/set$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
         self.pusher_pool = self.hs.get_pusherpool()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         user = requester.user
 
@@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers/remove$", v1=True)
     SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.notifier = hs.get_notifier()
         self.auth = hs.get_auth()
         self.pusher_pool = self.hs.get_pusherpool()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> None:
         requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
         user = requester.user
 
@@ -165,7 +171,7 @@ class PushersRemoveRestServlet(RestServlet):
         return None
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     PushersRestServlet(hs).register(http_server)
     PushersSetRestServlet(hs).register(http_server)
     PushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 027f8b81fa..43c04fac6f 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -13,27 +13,36 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.constants import ReadReceiptEventFields
 from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class ReadMarkerRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         await self.presence_handler.bump_presence_active_time(requester.user)
@@ -70,5 +79,5 @@ class ReadMarkerRestServlet(RestServlet):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReadMarkerRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 58b8e8f261..2781a0ea96 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -28,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
     UnrecognizedRequestError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config import ConfigError
 from synapse.config.captcha import CaptchaConfig
 from synapse.config.consent import ConsentConfig
@@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet):
             return 200, {"available": True}
 
 
+class RegistrationTokenValidityRestServlet(RestServlet):
+    """Check the validity of a registration token.
+
+    Example:
+
+        GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd
+
+        200 OK
+
+        {
+            "valid": true
+        }
+    """
+
+    PATTERNS = client_patterns(
+        f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity",
+        releases=(),
+        unstable=True,
+    )
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super().__init__()
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.ratelimiter = Ratelimiter(
+            store=self.store,
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second,
+            burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
+        )
+
+    async def on_GET(self, request):
+        await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
+
+        if not self.hs.config.enable_registration:
+            raise SynapseError(
+                403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+            )
+
+        token = parse_string(request, "token", required=True)
+        valid = await self.store.registration_token_is_valid(token)
+
+        return 200, {"valid": valid}
+
+
 class RegisterRestServlet(RestServlet):
     PATTERNS = client_patterns("/register$")
 
@@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet):
         )
 
         if registered:
+            # Check if a token was used to authenticate registration
+            registration_token = await self.auth_handler.get_session_data(
+                session_id,
+                UIAuthSessionDataConstants.REGISTRATION_TOKEN,
+            )
+            if registration_token:
+                # Increment the `completed` counter for the token
+                await self.store.use_registration_token(registration_token)
+                # Indicate that the token has been successfully used so that
+                # pending is not decremented again when expiring old UIA sessions.
+                await self.store.mark_ui_auth_stage_complete(
+                    session_id,
+                    LoginType.REGISTRATION_TOKEN,
+                    True,
+                )
+
             await self.registration_handler.post_registration_actions(
                 user_id=registered_user_id,
                 auth_result=auth_result,
@@ -868,6 +934,11 @@ def _calculate_registration_flows(
         for flow in flows:
             flow.insert(0, LoginType.RECAPTCHA)
 
+    # Prepend registration token to all flows if we're requiring a token
+    if config.registration_requires_token:
+        for flow in flows:
+            flow.insert(0, LoginType.REGISTRATION_TOKEN)
+
     return flows
 
 
@@ -876,4 +947,5 @@ def register_servlets(hs, http_server):
     MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
     UsernameAvailabilityRestServlet(hs).register(http_server)
     RegistrationSubmitTokenServlet(hs).register(http_server)
+    RegistrationTokenValidityRestServlet(hs).register(http_server)
     RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py
index 6d1b083acb..6a7792e18b 100644
--- a/synapse/rest/client/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/room_upgrade_rest_servlet.py
@@ -13,18 +13,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, ShadowBanError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
     parse_json_object_from_request,
 )
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 from synapse.util import stringutils
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet):
         }
 
     Creates a new room and shuts down the old one. Returns the ID of the new room.
-
-    Args:
-        hs (synapse.server.HomeServer):
     """
 
     PATTERNS = client_patterns(
@@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet):
         "/rooms/(?P<room_id>[^/]*)/upgrade$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self._hs = hs
         self._room_creation_handler = hs.get_room_creation_handler()
         self._auth = hs.get_auth()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self._auth.get_user_by_req(request)
 
         content = parse_json_object_from_request(request)
@@ -84,5 +90,5 @@ class RoomUpgradeRestServlet(RestServlet):
         return 200, ret
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RoomUpgradeRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py
index d2e7f04b40..1d90493eb0 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/shared_rooms.py
@@ -12,13 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet
-from synapse.types import UserID
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict, UserID
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet):
         releases=(),  # This is an unstable feature
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.user_directory_active = hs.config.update_user_directory
 
-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
 
         if not self.user_directory_active:
             raise SynapseError(
@@ -63,5 +71,5 @@ class UserSharedRoomsServlet(RestServlet):
         return 200, {"joined": list(rooms)}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e18f4d01b3..65c37be3e9 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,17 +14,26 @@
 import itertools
 import logging
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 from synapse.api.constants import Membership, PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
+from synapse.api.presence import UserPresenceState
 from synapse.events.utils import (
     format_event_for_client_v2_without_room_id,
     format_event_raw,
 )
 from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import KnockedSyncResult, SyncConfig
+from synapse.handlers.sync import (
+    ArchivedSyncResult,
+    InvitedSyncResult,
+    JoinedSyncResult,
+    KnockedSyncResult,
+    SyncConfig,
+    SyncResult,
+)
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, StreamToken
@@ -192,6 +201,8 @@ class SyncRestServlet(RestServlet):
             return 200, {}
 
         time_now = self.clock.time_msec()
+        # We know that the the requester has an access token since appservices
+        # cannot use sync.
         response_content = await self.encode_response(
             time_now, sync_result, requester.access_token_id, filter_collection
         )
@@ -199,7 +210,13 @@ class SyncRestServlet(RestServlet):
         logger.debug("Event formatting complete")
         return 200, response_content
 
-    async def encode_response(self, time_now, sync_result, access_token_id, filter):
+    async def encode_response(
+        self,
+        time_now: int,
+        sync_result: SyncResult,
+        access_token_id: Optional[int],
+        filter: FilterCollection,
+    ) -> JsonDict:
         logger.debug("Formatting events in sync response")
         if filter.event_format == "client":
             event_formatter = format_event_for_client_v2_without_room_id
@@ -234,7 +251,7 @@ class SyncRestServlet(RestServlet):
 
         logger.debug("building sync response dict")
 
-        response: dict = defaultdict(dict)
+        response: JsonDict = defaultdict(dict)
         response["next_batch"] = await sync_result.next_batch.to_string(self.store)
 
         if sync_result.account_data:
@@ -274,6 +291,8 @@ class SyncRestServlet(RestServlet):
         if archived:
             response["rooms"][Membership.LEAVE] = archived
 
+        # By the time we get here groups is no longer optional.
+        assert sync_result.groups is not None
         if sync_result.groups.join:
             response["groups"][Membership.JOIN] = sync_result.groups.join
         if sync_result.groups.invite:
@@ -284,7 +303,7 @@ class SyncRestServlet(RestServlet):
         return response
 
     @staticmethod
-    def encode_presence(events, time_now):
+    def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
         return {
             "events": [
                 {
@@ -299,25 +318,27 @@ class SyncRestServlet(RestServlet):
         }
 
     async def encode_joined(
-        self, rooms, time_now, token_id, event_fields, event_formatter
-    ):
+        self,
+        rooms: List[JoinedSyncResult],
+        time_now: int,
+        token_id: Optional[int],
+        event_fields: List[str],
+        event_formatter: Callable[[JsonDict], JsonDict],
+    ) -> JsonDict:
         """
         Encode the joined rooms in a sync result
 
         Args:
-            rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
-                results for rooms this user is joined to
-            time_now(int): current time - used as a baseline for age
-                calculations
-            token_id(int): ID of the user's auth token - used for namespacing
+            rooms: list of sync results for rooms this user is joined to
+            time_now: current time - used as a baseline for age calculations
+            token_id: ID of the user's auth token - used for namespacing
                 of transaction IDs
-            event_fields(list<str>): List of event fields to include. If empty,
+            event_fields: List of event fields to include. If empty,
                 all fields will be returned.
-            event_formatter (func[dict]): function to convert from federation format
+            event_formatter: function to convert from federation format
                 to client format
         Returns:
-            dict[str, dict[str, object]]: the joined rooms list, in our
-                response format
+            The joined rooms list, in our response format
         """
         joined = {}
         for room in rooms:
@@ -332,23 +353,26 @@ class SyncRestServlet(RestServlet):
 
         return joined
 
-    async def encode_invited(self, rooms, time_now, token_id, event_formatter):
+    async def encode_invited(
+        self,
+        rooms: List[InvitedSyncResult],
+        time_now: int,
+        token_id: Optional[int],
+        event_formatter: Callable[[JsonDict], JsonDict],
+    ) -> JsonDict:
         """
         Encode the invited rooms in a sync result
 
         Args:
-            rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
-                sync results for rooms this user is invited to
-            time_now(int): current time - used as a baseline for age
-                calculations
-            token_id(int): ID of the user's auth token - used for namespacing
+            rooms: list of sync results for rooms this user is invited to
+            time_now: current time - used as a baseline for age calculations
+            token_id: ID of the user's auth token - used for namespacing
                 of transaction IDs
-            event_formatter (func[dict]): function to convert from federation format
+            event_formatter: function to convert from federation format
                 to client format
 
         Returns:
-            dict[str, dict[str, object]]: the invited rooms list, in our
-                response format
+            The invited rooms list, in our response format
         """
         invited = {}
         for room in rooms:
@@ -371,7 +395,7 @@ class SyncRestServlet(RestServlet):
         self,
         rooms: List[KnockedSyncResult],
         time_now: int,
-        token_id: int,
+        token_id: Optional[int],
         event_formatter: Callable[[Dict], Dict],
     ) -> Dict[str, Dict[str, Any]]:
         """
@@ -422,25 +446,26 @@ class SyncRestServlet(RestServlet):
         return knocked
 
     async def encode_archived(
-        self, rooms, time_now, token_id, event_fields, event_formatter
-    ):
+        self,
+        rooms: List[ArchivedSyncResult],
+        time_now: int,
+        token_id: Optional[int],
+        event_fields: List[str],
+        event_formatter: Callable[[JsonDict], JsonDict],
+    ) -> JsonDict:
         """
         Encode the archived rooms in a sync result
 
         Args:
-            rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
-                sync results for rooms this user is joined to
-            time_now(int): current time - used as a baseline for age
-                calculations
-            token_id(int): ID of the user's auth token - used for namespacing
+            rooms: list of sync results for rooms this user is joined to
+            time_now: current time - used as a baseline for age calculations
+            token_id: ID of the user's auth token - used for namespacing
                 of transaction IDs
-            event_fields(list<str>): List of event fields to include. If empty,
+            event_fields: List of event fields to include. If empty,
                 all fields will be returned.
-            event_formatter (func[dict]): function to convert from federation format
-                to client format
+            event_formatter: function to convert from federation format to client format
         Returns:
-            dict[str, dict[str, object]]: The invited rooms list, in our
-                response format
+            The archived rooms list, in our response format
         """
         joined = {}
         for room in rooms:
@@ -456,23 +481,27 @@ class SyncRestServlet(RestServlet):
         return joined
 
     async def encode_room(
-        self, room, time_now, token_id, joined, only_fields, event_formatter
-    ):
+        self,
+        room: Union[JoinedSyncResult, ArchivedSyncResult],
+        time_now: int,
+        token_id: Optional[int],
+        joined: bool,
+        only_fields: Optional[List[str]],
+        event_formatter: Callable[[JsonDict], JsonDict],
+    ) -> JsonDict:
         """
         Args:
-            room (JoinedSyncResult|ArchivedSyncResult): sync result for a
-                single room
-            time_now (int): current time - used as a baseline for age
-                calculations
-            token_id (int): ID of the user's auth token - used for namespacing
+            room: sync result for a single room
+            time_now: current time - used as a baseline for age calculations
+            token_id: ID of the user's auth token - used for namespacing
                 of transaction IDs
-            joined (bool): True if the user is joined to this room - will mean
+            joined: True if the user is joined to this room - will mean
                 we handle ephemeral events
-            only_fields(list<str>): Optional. The list of event fields to include.
-            event_formatter (func[dict]): function to convert from federation format
+            only_fields: Optional. The list of event fields to include.
+            event_formatter: function to convert from federation format
                 to client format
         Returns:
-            dict[str, object]: the room, encoded in our response format
+            The room, encoded in our response format
         """
 
         def serialize(events):
@@ -508,7 +537,7 @@ class SyncRestServlet(RestServlet):
 
         account_data = room.account_data
 
-        result = {
+        result: JsonDict = {
             "timeline": {
                 "events": serialized_timeline,
                 "prev_batch": await room.timeline.prev_batch.to_string(self.store),
@@ -519,6 +548,7 @@ class SyncRestServlet(RestServlet):
         }
 
         if joined:
+            assert isinstance(room, JoinedSyncResult)
             ephemeral_events = room.ephemeral
             result["ephemeral"] = {"events": ephemeral_events}
             result["unread_notifications"] = room.unread_notifications
@@ -528,5 +558,5 @@ class SyncRestServlet(RestServlet):
         return result
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index c14f83be18..c88cb9367c 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -13,12 +13,19 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -29,12 +36,14 @@ class TagListServlet(RestServlet):
 
     PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, user_id, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot get tags for other users.")
@@ -54,12 +63,14 @@ class TagServlet(RestServlet):
         "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.handler = hs.get_account_data_handler()
 
-    async def on_PUT(self, request, user_id, room_id, tag):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
@@ -70,7 +81,9 @@ class TagServlet(RestServlet):
 
         return 200, {}
 
-    async def on_DELETE(self, request, user_id, room_id, tag):
+    async def on_DELETE(
+        self, request: SynapseRequest, user_id: str, room_id: str, tag: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
@@ -80,6 +93,6 @@ class TagServlet(RestServlet):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     TagListServlet(hs).register(http_server)
     TagServlet(hs).register(http_server)
diff --git a/synapse/rest/client/thirdparty.py b/synapse/rest/client/thirdparty.py
index b5c67c9bb6..b895c73acf 100644
--- a/synapse/rest/client/thirdparty.py
+++ b/synapse/rest/client/thirdparty.py
@@ -12,27 +12,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple
 
 from synapse.api.constants import ThirdPartyEntityKind
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class ThirdPartyProtocolsServlet(RestServlet):
     PATTERNS = client_patterns("/thirdparty/protocols")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
         self.auth = hs.get_auth()
         self.appservice_handler = hs.get_application_service_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         protocols = await self.appservice_handler.get_3pe_protocols()
@@ -42,13 +48,15 @@ class ThirdPartyProtocolsServlet(RestServlet):
 class ThirdPartyProtocolServlet(RestServlet):
     PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
         self.auth = hs.get_auth()
         self.appservice_handler = hs.get_application_service_handler()
 
-    async def on_GET(self, request, protocol):
+    async def on_GET(
+        self, request: SynapseRequest, protocol: str
+    ) -> Tuple[int, JsonDict]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
         protocols = await self.appservice_handler.get_3pe_protocols(
@@ -63,16 +71,18 @@ class ThirdPartyProtocolServlet(RestServlet):
 class ThirdPartyUserServlet(RestServlet):
     PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
         self.auth = hs.get_auth()
         self.appservice_handler = hs.get_application_service_handler()
 
-    async def on_GET(self, request, protocol):
+    async def on_GET(
+        self, request: SynapseRequest, protocol: str
+    ) -> Tuple[int, List[JsonDict]]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
-        fields = request.args
+        fields: Dict[bytes, List[bytes]] = request.args  # type: ignore[assignment]
         fields.pop(b"access_token", None)
 
         results = await self.appservice_handler.query_3pe(
@@ -85,16 +95,18 @@ class ThirdPartyUserServlet(RestServlet):
 class ThirdPartyLocationServlet(RestServlet):
     PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
         self.auth = hs.get_auth()
         self.appservice_handler = hs.get_application_service_handler()
 
-    async def on_GET(self, request, protocol):
+    async def on_GET(
+        self, request: SynapseRequest, protocol: str
+    ) -> Tuple[int, List[JsonDict]]:
         await self.auth.get_user_by_req(request, allow_guest=True)
 
-        fields = request.args
+        fields: Dict[bytes, List[bytes]] = request.args  # type: ignore[assignment]
         fields.pop(b"access_token", None)
 
         results = await self.appservice_handler.query_3pe(
@@ -104,7 +116,7 @@ class ThirdPartyLocationServlet(RestServlet):
         return 200, results
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ThirdPartyProtocolsServlet(hs).register(http_server)
     ThirdPartyProtocolServlet(hs).register(http_server)
     ThirdPartyUserServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py
index b2f858545c..c8c3b25bd3 100644
--- a/synapse/rest/client/tokenrefresh.py
+++ b/synapse/rest/client/tokenrefresh.py
@@ -12,11 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import TYPE_CHECKING
+
+from twisted.web.server import Request
+
 from synapse.api.errors import AuthError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class TokenRefreshRestServlet(RestServlet):
     """
@@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet):
 
     PATTERNS = client_patterns("/tokenrefresh")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> None:
         raise AuthError(403, "tokenrefresh is no longer supported.")
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     TokenRefreshRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index 7e8912f0b9..8852811114 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -13,29 +13,32 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class UserDirectorySearchRestServlet(RestServlet):
     PATTERNS = client_patterns("/user_directory/search$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.user_directory_handler = hs.get_user_directory_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """Searches for users in directory
 
         Returns:
@@ -75,5 +78,5 @@ class UserDirectorySearchRestServlet(RestServlet):
         return 200, results
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     UserDirectorySearchRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index fa2e4e9cba..a1a815cf82 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -17,9 +17,17 @@
 
 import logging
 import re
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
 
 from synapse.api.constants import RoomCreationPreset
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -27,7 +35,7 @@ logger = logging.getLogger(__name__)
 class VersionsRestServlet(RestServlet):
     PATTERNS = [re.compile("^/_matrix/client/versions$")]
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.config = hs.config
 
@@ -45,7 +53,7 @@ class VersionsRestServlet(RestServlet):
             in self.config.encryption_enabled_by_default_for_room_presets
         )
 
-    def on_GET(self, request):
+    def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
         return (
             200,
             {
@@ -89,5 +97,5 @@ class VersionsRestServlet(RestServlet):
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     VersionsRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
index f53020520d..9d46ed3af3 100644
--- a/synapse/rest/client/voip.py
+++ b/synapse/rest/client/voip.py
@@ -15,20 +15,27 @@
 import base64
 import hashlib
 import hmac
+from typing import TYPE_CHECKING, Tuple
 
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 
 class VoipRestServlet(RestServlet):
     PATTERNS = client_patterns("/voip/turnServer$", v1=True)
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(
             request, self.hs.config.turn_allow_guests
         )
@@ -69,5 +76,5 @@ class VoipRestServlet(RestServlet):
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     VoipRestServlet(hs).register(http_server)
diff --git a/synapse/static/client/register/style.css b/synapse/static/client/register/style.css
index 5a7b6eebf2..8a39b5d0f5 100644
--- a/synapse/static/client/register/style.css
+++ b/synapse/static/client/register/style.css
@@ -57,4 +57,8 @@ textarea, input {
     
     background-color: #f8f8f8;
     border: 1px #ccc solid;
-}
\ No newline at end of file
+}
+
+.error {
+	color: red;
+}
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 01b918e12e..00a644e8f7 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -63,6 +63,7 @@ from .relations import RelationsStore
 from .room import RoomStore
 from .roommember import RoomMemberStore
 from .search import SearchStore
+from .session import SessionStore
 from .signatures import SignatureStore
 from .state import StateStore
 from .stats import StatsStore
@@ -121,6 +122,7 @@ class DataStore(
     ServerMetricsStore,
     EventForwardExtremitiesStore,
     LockStore,
+    SessionStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
         self.hs = hs
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c67bea81c6..a6517962f6 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
         return user_id
 
-    def get_user_id_by_threepid_txn(self, txn, medium, address):
+    def get_user_id_by_threepid_txn(
+        self, txn, medium: str, address: str
+    ) -> Optional[str]:
         """Returns user id from threepid
 
         Args:
             txn (cursor):
-            medium (str): threepid medium e.g. email
-            address (str): threepid address e.g. me@example.com
+            medium: threepid medium e.g. email
+            address: threepid address e.g. me@example.com
 
         Returns:
-            str|None: user id or None if no user id/threepid mapping exists
+            user id, or None if no user id/threepid mapping exists
         """
         ret = self.db_pool.simple_select_one_txn(
             txn,
@@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             return ret["user_id"]
         return None
 
-    async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+    async def user_add_threepid(
+        self,
+        user_id: str,
+        medium: str,
+        address: str,
+        validated_at: int,
+        added_at: int,
+    ) -> None:
         await self.db_pool.simple_upsert(
             "user_threepids",
             {"medium": medium, "address": address},
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
         )
 
-    async def user_get_threepids(self, user_id):
+    async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
         return await self.db_pool.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
@@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "user_get_threepids",
         )
 
-    async def user_delete_threepid(self, user_id, medium, address) -> None:
+    async def user_delete_threepid(
+        self, user_id: str, medium: str, address: str
+    ) -> None:
         await self.db_pool.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
@@ -1157,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="update_access_token_last_validated",
         )
 
+    async def registration_token_is_valid(self, token: str) -> bool:
+        """Checks if a token can be used to authenticate a registration.
+
+        Args:
+            token: The registration token to be checked
+        Returns:
+            True if the token is valid, False otherwise.
+        """
+        res = await self.db_pool.simple_select_one(
+            "registration_tokens",
+            keyvalues={"token": token},
+            retcols=["uses_allowed", "pending", "completed", "expiry_time"],
+            allow_none=True,
+        )
+
+        # Check if the token exists
+        if res is None:
+            return False
+
+        # Check if the token has expired
+        now = self._clock.time_msec()
+        if res["expiry_time"] and res["expiry_time"] < now:
+            return False
+
+        # Check if the token has been used up
+        if (
+            res["uses_allowed"]
+            and res["pending"] + res["completed"] >= res["uses_allowed"]
+        ):
+            return False
+
+        # Otherwise, the token is valid
+        return True
+
+    async def set_registration_token_pending(self, token: str) -> None:
+        """Increment the pending registrations counter for a token.
+
+        Args:
+            token: The registration token pending use
+        """
+
+        def _set_registration_token_pending_txn(txn):
+            pending = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+            self.db_pool.simple_update_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={"pending": pending + 1},
+            )
+
+        return await self.db_pool.runInteraction(
+            "set_registration_token_pending", _set_registration_token_pending_txn
+        )
+
+    async def use_registration_token(self, token: str) -> None:
+        """Complete a use of the given registration token.
+
+        The `pending` counter will be decremented, and the `completed`
+        counter will be incremented.
+
+        Args:
+            token: The registration token to be 'used'
+        """
+
+        def _use_registration_token_txn(txn):
+            # Normally, res is Optional[Dict[str, Any]].
+            # Override type because the return type is only optional if
+            # allow_none is True, and we don't want mypy throwing errors
+            # about None not being indexable.
+            res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )  # type: ignore
+
+            # Decrement pending and increment completed
+            self.db_pool.simple_update_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={
+                    "completed": res["completed"] + 1,
+                    "pending": res["pending"] - 1,
+                },
+            )
+
+        return await self.db_pool.runInteraction(
+            "use_registration_token", _use_registration_token_txn
+        )
+
+    async def get_registration_tokens(
+        self, valid: Optional[bool] = None
+    ) -> List[Dict[str, Any]]:
+        """List all registration tokens. Used by the admin API.
+
+        Args:
+            valid: If True, only valid tokens are returned.
+              If False, only invalid tokens are returned.
+              Default is None: return all tokens regardless of validity.
+
+        Returns:
+            A list of dicts, each containing details of a token.
+        """
+
+        def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+            if valid is None:
+                # Return all tokens regardless of validity
+                txn.execute("SELECT * FROM registration_tokens")
+
+            elif valid:
+                # Select valid tokens only
+                sql = (
+                    "SELECT * FROM registration_tokens WHERE "
+                    "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
+                    "AND (expiry_time > ? OR expiry_time IS NULL)"
+                )
+                txn.execute(sql, [now])
+
+            else:
+                # Select invalid tokens only
+                sql = (
+                    "SELECT * FROM registration_tokens WHERE "
+                    "uses_allowed <= pending + completed OR expiry_time <= ?"
+                )
+                txn.execute(sql, [now])
+
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "select_registration_tokens",
+            select_registration_tokens_txn,
+            self._clock.time_msec(),
+            valid,
+        )
+
+    async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
+        """Get info about the given registration token. Used by the admin API.
+
+        Args:
+            token: The token to retrieve information about.
+
+        Returns:
+            A dict, or None if token doesn't exist.
+        """
+        return await self.db_pool.simple_select_one(
+            "registration_tokens",
+            keyvalues={"token": token},
+            retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
+            allow_none=True,
+            desc="get_one_registration_token",
+        )
+
+    async def generate_registration_token(
+        self, length: int, chars: str
+    ) -> Optional[str]:
+        """Generate a random registration token. Used by the admin API.
+
+        Args:
+            length: The length of the token to generate.
+            chars: A string of the characters allowed in the generated token.
+
+        Returns:
+            The generated token.
+
+        Raises:
+            SynapseError if a unique registration token could still not be
+            generated after a few tries.
+        """
+        # Make a few attempts at generating a unique token of the required
+        # length before failing.
+        for _i in range(3):
+            # Generate token
+            token = "".join(random.choices(chars, k=length))
+
+            # Check if the token already exists
+            existing_token = await self.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="token",
+                allow_none=True,
+                desc="check_if_registration_token_exists",
+            )
+
+            if existing_token is None:
+                # The generated token doesn't exist yet, return it
+                return token
+
+        raise SynapseError(
+            500,
+            "Unable to generate a unique registration token. Try again with a greater length",
+            Codes.UNKNOWN,
+        )
+
+    async def create_registration_token(
+        self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
+    ) -> bool:
+        """Create a new registration token. Used by the admin API.
+
+        Args:
+            token: The token to create.
+            uses_allowed: The number of times the token can be used to complete
+              a registration before it becomes invalid. A value of None indicates
+              unlimited uses.
+            expiry_time: The latest time the token is valid. Given as the
+              number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
+              None indicates that the token does not expire.
+
+        Returns:
+            Whether the row was inserted or not.
+        """
+
+        def _create_registration_token_txn(txn):
+            row = self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["token"],
+                allow_none=True,
+            )
+
+            if row is not None:
+                # Token already exists
+                return False
+
+            self.db_pool.simple_insert_txn(
+                txn,
+                "registration_tokens",
+                values={
+                    "token": token,
+                    "uses_allowed": uses_allowed,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": expiry_time,
+                },
+            )
+
+            return True
+
+        return await self.db_pool.runInteraction(
+            "create_registration_token", _create_registration_token_txn
+        )
+
+    async def update_registration_token(
+        self, token: str, updatevalues: Dict[str, Optional[int]]
+    ) -> Optional[Dict[str, Any]]:
+        """Update a registration token. Used by the admin API.
+
+        Args:
+            token: The token to update.
+            updatevalues: A dict with the fields to update. E.g.:
+              `{"uses_allowed": 3}` to update just uses_allowed, or
+              `{"uses_allowed": 3, "expiry_time": None}` to update both.
+              This is passed straight to simple_update_one.
+
+        Returns:
+            A dict with all info about the token, or None if token doesn't exist.
+        """
+
+        def _update_registration_token_txn(txn):
+            try:
+                self.db_pool.simple_update_one_txn(
+                    txn,
+                    "registration_tokens",
+                    keyvalues={"token": token},
+                    updatevalues=updatevalues,
+                )
+            except StoreError:
+                # Update failed because token does not exist
+                return None
+
+            # Get all info about the token so it can be sent in the response
+            return self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=[
+                    "token",
+                    "uses_allowed",
+                    "pending",
+                    "completed",
+                    "expiry_time",
+                ],
+                allow_none=True,
+            )
+
+        return await self.db_pool.runInteraction(
+            "update_registration_token", _update_registration_token_txn
+        )
+
+    async def delete_registration_token(self, token: str) -> bool:
+        """Delete a registration token. Used by the admin API.
+
+        Args:
+            token: The token to delete.
+
+        Returns:
+            Whether the token was successfully deleted or not.
+        """
+        try:
+            await self.db_pool.simple_delete_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                desc="delete_registration_token",
+            )
+        except StoreError:
+            # Deletion failed because token does not exist
+            return False
+
+        return True
+
     @cached()
     async def mark_access_token_as_used(self, token_id: int) -> None:
         """
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e8157ba3d4..c58a4b8690 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached()
-    async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
+    async def get_invited_rooms_for_local_user(
+        self, user_id: str
+    ) -> List[RoomsForUser]:
         """Get all the rooms the *local* user is invited to.
 
         Args:
@@ -384,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
         sql = """
-            SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+            SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
             FROM local_current_membership AS c
             INNER JOIN events AS e USING (room_id, event_id)
+            INNER JOIN rooms AS r USING (room_id)
             WHERE
                 user_id = ?
                 AND %s
@@ -395,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
         txn.execute(sql, (user_id, *args))
-        results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
+        results = [RoomsForUser(*r) for r in txn]
 
         return results
 
@@ -445,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns:
             Returns the rooms the user is in currently, along with the stream
-            ordering of the most recent join for that user and room.
+            ordering of the most recent join for that user and room, along with
+            the room version of the room.
         """
         return await self.db_pool.runInteraction(
             "get_rooms_for_user_with_stream_ordering",
@@ -522,7 +526,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             _get_users_server_still_shares_room_with_txn,
         )
 
-    async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
+    async def get_rooms_for_user(
+        self, user_id: str, on_invalidate=None
+    ) -> FrozenSet[str]:
         """Returns a set of room_ids the user is currently joined to.
 
         If a remote user only returns rooms this server is currently
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
new file mode 100644
index 0000000000..172f27d109
--- /dev/null
+++ b/synapse/storage/databases/main/session.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from typing import TYPE_CHECKING
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.types import JsonDict
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class SessionStore(SQLBaseStore):
+    """
+    A store for generic session data.
+
+    Each type of session should provide a unique type (to separate sessions).
+
+    Sessions are automatically removed when they expire.
+    """
+
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        # Create a background job for culling expired sessions.
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
+
+    async def create_session(
+        self, session_type: str, value: JsonDict, expiry_ms: int
+    ) -> str:
+        """
+        Creates a new pagination session for the room hierarchy endpoint.
+
+        Args:
+            session_type: The type for this session.
+            value: The value to store.
+            expiry_ms: How long before an item is evicted from the cache
+                in milliseconds. Default is 0, indicating items never get
+                evicted based on time.
+
+        Returns:
+            The newly created session ID.
+
+        Raises:
+            StoreError if a unique session ID cannot be generated.
+        """
+        # autogen a session ID and try to create it. We may clash, so just
+        # try a few times till one goes through, giving up eventually.
+        attempts = 0
+        while attempts < 5:
+            session_id = stringutils.random_string(24)
+
+            try:
+                await self.db_pool.simple_insert(
+                    table="sessions",
+                    values={
+                        "session_id": session_id,
+                        "session_type": session_type,
+                        "value": json_encoder.encode(value),
+                        "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
+                    },
+                    desc="create_session",
+                )
+
+                return session_id
+            except self.db_pool.engine.module.IntegrityError:
+                attempts += 1
+        raise StoreError(500, "Couldn't generate a session ID.")
+
+    async def get_session(self, session_type: str, session_id: str) -> JsonDict:
+        """
+        Retrieve data stored with create_session
+
+        Args:
+            session_type: The type for this session.
+            session_id: The session ID returned from create_session.
+
+        Raises:
+            StoreError if the session cannot be found.
+        """
+
+        def _get_session(
+            txn: LoggingTransaction, session_type: str, session_id: str, ts: int
+        ) -> JsonDict:
+            # This includes the expiry time since items are only periodically
+            # deleted, not upon expiry.
+            select_sql = """
+            SELECT value FROM sessions WHERE
+            session_type = ? AND session_id = ? AND expiry_time_ms > ?
+            """
+            txn.execute(select_sql, [session_type, session_id, ts])
+            row = txn.fetchone()
+
+            if not row:
+                raise StoreError(404, "No session")
+
+            return db_to_json(row[0])
+
+        return await self.db_pool.runInteraction(
+            "get_session",
+            _get_session,
+            session_type,
+            session_id,
+            self._clock.time_msec(),
+        )
+
+    @wrap_as_background_process("delete_expired_sessions")
+    async def _delete_expired_sessions(self) -> None:
+        """Remove sessions with expiry dates that have passed."""
+
+        def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
+            sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
+            txn.execute(sql, (ts,))
+
+        await self.db_pool.runInteraction(
+            "delete_expired_sessions",
+            _delete_expired_sessions_txn,
+            self._clock.time_msec(),
+        )
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5dad..4d6bbc94c7 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import attr
 
+from synapse.api.constants import LoginType
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
@@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
             keyvalues={},
         )
 
+        # If a registration token was used, decrement the pending counter
+        # before deleting the session.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="ui_auth_sessions_credentials",
+            column="session_id",
+            iterable=session_ids,
+            keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+            retcols=["result"],
+        )
+
+        # Get the tokens used and how much pending needs to be decremented by.
+        token_counts: Dict[str, int] = {}
+        for r in rows:
+            # If registration was successfully completed, the result of the
+            # registration token stage for that session will be True.
+            # If a token was used to authenticate, but registration was
+            # never completed, the result will be the token used.
+            token = db_to_json(r["result"])
+            if isinstance(token, str):
+                token_counts[token] = token_counts.get(token, 0) + 1
+
+        # Update the `pending` counters.
+        if len(token_counts) > 0:
+            token_rows = self.db_pool.simple_select_many_txn(
+                txn,
+                table="registration_tokens",
+                column="token",
+                iterable=list(token_counts.keys()),
+                keyvalues={},
+                retcols=["token", "pending"],
+            )
+            for token_row in token_rows:
+                token = token_row["token"]
+                new_pending = token_row["pending"] - token_counts[token]
+                self.db_pool.simple_update_one_txn(
+                    txn,
+                    table="registration_tokens",
+                    keyvalues={"token": token},
+                    updatevalues={"pending": new_pending},
+                )
+
         # Delete the corresponding completed credentials.
         self.db_pool.simple_delete_many_txn(
             txn,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 9d28d69ac7..65dde67ae9 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         return False
 
     async def update_profile_in_user_dir(
-        self, user_id: str, display_name: str, avatar_url: str
+        self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
     ) -> None:
         """
         Update or add a user's profile in the user directory.
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index c34fbf21bc..2500381b7b 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -14,25 +14,41 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
+from typing import List, Optional, Tuple
+
+import attr
+
+from synapse.types import PersistedEventPosition
 
 logger = logging.getLogger(__name__)
 
 
-RoomsForUser = namedtuple(
-    "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
-)
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class RoomsForUser:
+    room_id: str
+    sender: str
+    membership: str
+    event_id: str
+    stream_ordering: int
+    room_version_id: str
+
+
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class GetRoomsForUserWithStreamOrdering:
+    room_id: str
+    event_pos: PersistedEventPosition
 
-GetRoomsForUserWithStreamOrdering = namedtuple(
-    "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
-)
 
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class ProfileInfo:
+    avatar_url: Optional[str]
+    display_name: Optional[str]
 
-# We store this using a namedtuple so that we save about 3x space over using a
-# dict.
-ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
 
-# "members" points to a truncated list of (user_id, event_id) tuples for users of
-# a given membership type, suitable for use in calculating heroes for a room.
-# "count" points to the total numberr of users of a given membership type.
-MemberSummary = namedtuple("MemberSummary", ("members", "count"))
+@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True)
+class MemberSummary:
+    # A truncated list of (user_id, event_id) tuples for users of a given
+    # membership type, suitable for use in calculating heroes for a room.
+    members: List[Tuple[str, str]]
+    # The total number of users of a given membership type.
+    count: int
diff --git a/synapse/storage/schema/main/delta/62/02session_store.sql b/synapse/storage/schema/main/delta/62/02session_store.sql
new file mode 100644
index 0000000000..535fb34c10
--- /dev/null
+++ b/synapse/storage/schema/main/delta/62/02session_store.sql
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2021 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS sessions(
+    session_type TEXT NOT NULL,  -- The unique key for this type of session.
+    session_id TEXT NOT NULL,  -- The session ID passed to the client.
+    value TEXT NOT NULL, -- A JSON dictionary to persist.
+    expiry_time_ms BIGINT NOT NULL,  -- The time this session will expire (epoch time in milliseconds).
+    UNIQUE (session_type, session_id)
+);
diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
new file mode 100644
index 0000000000..ee6cf958f4
--- /dev/null
+++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql
@@ -0,0 +1,23 @@
+/* Copyright 2021 Callum Brown
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS registration_tokens(
+    token TEXT NOT NULL,  -- The token that can be used for authentication.
+    uses_allowed INT,  -- The total number of times this token can be used. NULL if no limit.
+    pending INT NOT NULL, -- The number of in progress registrations using this token.
+    completed INT NOT NULL, -- The number of times this token has been used to complete a registration.
+    expiry_time BIGINT,  -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire.
+    UNIQUE (token)
+);
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 6b87f571b8..3b3866bff8 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
 import attr
 
 from synapse.api.constants import EduTypes
-from synapse.events.presence_router import PresenceRouter
+from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
 from synapse.federation.units import Transaction
 from synapse.handlers.presence import UserPresenceState
 from synapse.module_api import ModuleApi
@@ -34,7 +34,7 @@ class PresenceRouterTestConfig:
     users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
 
 
-class PresenceRouterTestModule:
+class LegacyPresenceRouterTestModule:
     def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
         self._config = config
         self._module_api = module_api
@@ -77,6 +77,53 @@ class PresenceRouterTestModule:
         return config
 
 
+class PresenceRouterTestModule:
+    def __init__(self, config: PresenceRouterTestConfig, api: ModuleApi):
+        self._config = config
+        self._module_api = api
+        api.register_presence_router_callbacks(
+            get_users_for_states=self.get_users_for_states,
+            get_interested_users=self.get_interested_users,
+        )
+
+    async def get_users_for_states(
+        self, state_updates: Iterable[UserPresenceState]
+    ) -> Dict[str, Set[UserPresenceState]]:
+        users_to_state = {
+            user_id: set(state_updates)
+            for user_id in self._config.users_who_should_receive_all_presence
+        }
+        return users_to_state
+
+    async def get_interested_users(
+        self, user_id: str
+    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+        if user_id in self._config.users_who_should_receive_all_presence:
+            return PresenceRouter.ALL_USERS
+
+        return set()
+
+    @staticmethod
+    def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
+        """Parse a configuration dictionary from the homeserver config, do
+        some validation and return a typed PresenceRouterConfig.
+
+        Args:
+            config_dict: The configuration dictionary.
+
+        Returns:
+            A validated config object.
+        """
+        # Initialise a typed config object
+        config = PresenceRouterTestConfig()
+
+        config.users_who_should_receive_all_presence = config_dict.get(
+            "users_who_should_receive_all_presence"
+        )
+
+        return config
+
+
 class PresenceRouterTestCase(FederatingHomeserverTestCase):
     servlets = [
         admin.register_servlets,
@@ -86,9 +133,17 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor, clock):
-        return self.setup_test_homeserver(
+        hs = self.setup_test_homeserver(
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
+        # Load the modules into the homeserver
+        module_api = hs.get_module_api()
+        for module, config in hs.config.modules.loaded_modules:
+            module(config=config, api=module_api)
+
+        load_legacy_presence_router(hs)
+
+        return hs
 
     def prepare(self, reactor, clock, homeserver):
         self.sync_handler = self.hs.get_sync_handler()
@@ -98,7 +153,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         {
             "presence": {
                 "presence_router": {
-                    "module": __name__ + ".PresenceRouterTestModule",
+                    "module": __name__ + ".LegacyPresenceRouterTestModule",
                     "config": {
                         "users_who_should_receive_all_presence": [
                             "@presence_gobbler:test",
@@ -109,7 +164,28 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             "send_federation": True,
         }
     )
+    def test_receiving_all_presence_legacy(self):
+        self.receiving_all_presence_test_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".PresenceRouterTestModule",
+                    "config": {
+                        "users_who_should_receive_all_presence": [
+                            "@presence_gobbler:test",
+                        ]
+                    },
+                },
+            ],
+            "send_federation": True,
+        }
+    )
     def test_receiving_all_presence(self):
+        self.receiving_all_presence_test_body()
+
+    def receiving_all_presence_test_body(self):
         """Test that a user that does not share a room with another other can receive
         presence for them, due to presence routing.
         """
@@ -203,7 +279,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         {
             "presence": {
                 "presence_router": {
-                    "module": __name__ + ".PresenceRouterTestModule",
+                    "module": __name__ + ".LegacyPresenceRouterTestModule",
                     "config": {
                         "users_who_should_receive_all_presence": [
                             "@presence_gobbler1:test",
@@ -216,7 +292,30 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             "send_federation": True,
         }
     )
+    def test_send_local_online_presence_to_with_module_legacy(self):
+        self.send_local_online_presence_to_with_module_test_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".PresenceRouterTestModule",
+                    "config": {
+                        "users_who_should_receive_all_presence": [
+                            "@presence_gobbler1:test",
+                            "@presence_gobbler2:test",
+                            "@far_away_person:island",
+                        ]
+                    },
+                },
+            ],
+            "send_federation": True,
+        }
+    )
     def test_send_local_online_presence_to_with_module(self):
+        self.send_local_online_presence_to_with_module_test_body()
+
+    def send_local_online_presence_to_with_module_test_body(self):
         """Tests that send_local_presence_to_users sends local online presence to a set
         of specified local and remote users, with a custom PresenceRouter module enabled.
         """
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 84f05f6c58..339c039914 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,9 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
+from synapse.api.constants import EventTypes, JoinRules
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
+from synapse.api.room_versions import RoomVersions
 from synapse.handlers.sync import SyncConfig
+from synapse.rest import admin
+from synapse.rest.client import knock, login, room
+from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
 
 import tests.unittest
@@ -24,8 +31,14 @@ import tests.utils
 class SyncTestCase(tests.unittest.HomeserverTestCase):
     """Tests Sync Handler."""
 
-    def prepare(self, reactor, clock, hs):
-        self.hs = hs
+    servlets = [
+        admin.register_servlets,
+        knock.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs: HomeServer):
         self.sync_handler = self.hs.get_sync_handler()
         self.store = self.hs.get_datastore()
 
@@ -68,12 +81,124 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         )
         self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
+    def test_unknown_room_version(self):
+        """
+        A room with an unknown room version should not break sync (and should be excluded).
+        """
+        inviter = self.register_user("creator", "pass", admin=True)
+        inviter_tok = self.login("@creator:test", "pass")
+
+        user = self.register_user("user", "pass")
+        tok = self.login("user", "pass")
+
+        # Do an initial sync on a different device.
+        requester = create_requester(user)
+        initial_result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester, sync_config=generate_sync_config(user, device_id="dev")
+            )
+        )
+
+        # Create a room as the user.
+        joined_room = self.helper.create_room_as(user, tok=tok)
+
+        # Invite the user to the room as someone else.
+        invite_room = self.helper.create_room_as(inviter, tok=inviter_tok)
+        self.helper.invite(invite_room, targ=user, tok=inviter_tok)
+
+        knock_room = self.helper.create_room_as(
+            inviter, room_version=RoomVersions.V7.identifier, tok=inviter_tok
+        )
+        self.helper.send_state(
+            knock_room,
+            EventTypes.JoinRules,
+            {"join_rule": JoinRules.KNOCK},
+            tok=inviter_tok,
+        )
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/knock/%s" % (knock_room,),
+            b"{}",
+            tok,
+        )
+        self.assertEquals(200, channel.code, channel.result)
+
+        # The rooms should appear in the sync response.
+        result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester, sync_config=generate_sync_config(user)
+            )
+        )
+        self.assertIn(joined_room, [r.room_id for r in result.joined])
+        self.assertIn(invite_room, [r.room_id for r in result.invited])
+        self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+        # Test a incremental sync (by providing a since_token).
+        result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester,
+                sync_config=generate_sync_config(user, device_id="dev"),
+                since_token=initial_result.next_batch,
+            )
+        )
+        self.assertIn(joined_room, [r.room_id for r in result.joined])
+        self.assertIn(invite_room, [r.room_id for r in result.invited])
+        self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+        # Poke the database and update the room version to an unknown one.
+        for room_id in (joined_room, invite_room, knock_room):
+            self.get_success(
+                self.hs.get_datastores().main.db_pool.simple_update(
+                    "rooms",
+                    keyvalues={"room_id": room_id},
+                    updatevalues={"room_version": "unknown-room-version"},
+                    desc="updated-room-version",
+                )
+            )
+
+        # Blow away caches (supported room versions can only change due to a restart).
+        self.get_success(
+            self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
+        )
+        self.store._get_event_cache.clear()
+
+        # The rooms should be excluded from the sync response.
+        # Get a new request key.
+        result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester, sync_config=generate_sync_config(user)
+            )
+        )
+        self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+        self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+        self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+        # The rooms should also not be in an incremental sync.
+        result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester,
+                sync_config=generate_sync_config(user, device_id="dev"),
+                since_token=initial_result.next_batch,
+            )
+        )
+        self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+        self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+        self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+
+_request_key = 0
+
 
-def generate_sync_config(user_id: str) -> SyncConfig:
+def generate_sync_config(
+    user_id: str, device_id: Optional[str] = "device_id"
+) -> SyncConfig:
+    """Generate a sync config (with a unique request key)."""
+    global _request_key
+    _request_key += 1
     return SyncConfig(
-        user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
+        user=UserID.from_string(user_id),
         filter_collection=DEFAULT_FILTER_COLLECTION,
         is_guest=False,
-        request_key="request_key",
-        device_id="device_id",
+        request_key=("request_key", _request_key),
+        device_id=device_id,
     )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index db80a0bdbd..b25a06b427 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,7 +20,7 @@ from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
 from synapse.handlers.room import RoomEventSource
 from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
 from synapse.types import PersistedEventPosition
 
 from tests.server import FakeTransport
@@ -150,6 +150,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                     "invite",
                     event.event_id,
                     event.internal_metadata.stream_ordering,
+                    RoomVersions.V1.identifier,
                 )
             ],
         )
@@ -216,7 +217,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_rooms_for_user_with_stream_ordering",
             (USER_ID_2,),
-            {(ROOM_ID, expected_pos)},
+            {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
         )
 
     def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -305,7 +306,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                 expected_pos = PersistedEventPosition(
                     "master", j2.internal_metadata.stream_ordering
                 )
-                self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
+                self.assertEqual(
+                    joined_rooms,
+                    {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
+                )
 
     event_id = 0
 
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index c4afe5c3d9..a3679be205 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import urllib.parse
 
+from parameterized import parameterized
+
 import synapse.rest.admin
 from synapse.api.errors import Codes
 from synapse.rest.client import login
@@ -45,49 +46,23 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             self.other_user_device_id,
         )
 
-    def test_no_auth(self):
+    @parameterized.expand(["GET", "PUT", "DELETE"])
+    def test_no_auth(self, method: str):
         """
         Try to get a device of an user without authentication.
         """
-        channel = self.make_request("GET", self.url, b"{}")
-
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
-        channel = self.make_request("PUT", self.url, b"{}")
-
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
-        channel = self.make_request("DELETE", self.url, b"{}")
+        channel = self.make_request(method, self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-    def test_requester_is_no_admin(self):
+    @parameterized.expand(["GET", "PUT", "DELETE"])
+    def test_requester_is_no_admin(self, method: str):
         """
         If the user is not a server admin, an error is returned.
         """
         channel = self.make_request(
-            "GET",
-            self.url,
-            access_token=self.other_user_token,
-        )
-
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "PUT",
-            self.url,
-            access_token=self.other_user_token,
-        )
-
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "DELETE",
+            method,
             self.url,
             access_token=self.other_user_token,
         )
@@ -95,7 +70,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_user_does_not_exist(self):
+    @parameterized.expand(["GET", "PUT", "DELETE"])
+    def test_user_does_not_exist(self, method: str):
         """
         Tests that a lookup for a user that does not exist returns a 404
         """
@@ -105,7 +81,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         )
 
         channel = self.make_request(
-            "GET",
+            method,
             url,
             access_token=self.admin_user_tok,
         )
@@ -113,25 +89,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        channel = self.make_request(
-            "PUT",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "DELETE",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
-    def test_user_is_not_local(self):
+    @parameterized.expand(["GET", "PUT", "DELETE"])
+    def test_user_is_not_local(self, method: str):
         """
         Tests that a lookup for a user that is not a local returns a 400
         """
@@ -141,25 +100,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         )
 
         channel = self.make_request(
-            "GET",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
-        channel = self.make_request(
-            "PUT",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
-        channel = self.make_request(
-            "DELETE",
+            method,
             url,
             access_token=self.admin_user_tok,
         )
@@ -219,12 +160,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
         }
 
-        body = json.dumps(update)
         channel = self.make_request(
             "PUT",
             self.url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content=update,
         )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -275,12 +215,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         Tests a normal successful update of display name
         """
         # Set new display_name
-        body = json.dumps({"display_name": "new displayname"})
         channel = self.make_request(
             "PUT",
             self.url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"display_name": "new displayname"},
         )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -529,12 +468,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         Tests that a remove of a device that does not exist returns 200.
         """
-        body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
         channel = self.make_request(
             "POST",
             self.url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"devices": ["unknown_device1", "unknown_device2"]},
         )
 
         # Delete unknown devices returns status 200
@@ -560,12 +498,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
             device_ids.append(str(d["device_id"]))
 
         # Delete devices
-        body = json.dumps({"devices": device_ids})
         channel = self.make_request(
             "POST",
             self.url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"devices": device_ids},
         )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
new file mode 100644
index 0000000000..4927321e5a
--- /dev/null
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -0,0 +1,710 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import string
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        self.url = "/_synapse/admin/v1/registration_tokens"
+
+    def _new_token(self, **kwargs):
+        """Helper function to create a token."""
+        token = kwargs.get(
+            "token",
+            "".join(random.choices(string.ascii_letters, k=8)),
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": kwargs.get("uses_allowed", None),
+                    "pending": kwargs.get("pending", 0),
+                    "completed": kwargs.get("completed", 0),
+                    "expiry_time": kwargs.get("expiry_time", None),
+                },
+            )
+        )
+        return token
+
+    # CREATION
+
+    def test_create_no_auth(self):
+        """Try to create a token without authentication."""
+        channel = self.make_request("POST", self.url + "/new", {})
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_create_requester_not_admin(self):
+        """Try to create a token while not an admin."""
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_create_using_defaults(self):
+        """Create a token using all the defaults."""
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["token"]), 16)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    def test_create_specifying_fields(self):
+        """Create a token specifying the value of all fields."""
+        data = {
+            "token": "abcd",
+            "uses_allowed": 1,
+            "expiry_time": self.clock.time_msec() + 1000000,
+        }
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["token"], "abcd")
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    def test_create_with_null_value(self):
+        """Create a token specifying unlimited uses and no expiry."""
+        data = {
+            "uses_allowed": None,
+            "expiry_time": None,
+        }
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["token"]), 16)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    def test_create_token_too_long(self):
+        """Check token longer than 64 chars is invalid."""
+        data = {"token": "a" * 65}
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_token_invalid_chars(self):
+        """Check you can't create token with invalid characters."""
+        data = {
+            "token": "abc/def",
+        }
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_token_already_exists(self):
+        """Check you can't create token that already exists."""
+        data = {
+            "token": "abcd",
+        }
+
+        channel1 = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+
+        channel2 = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+        self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_unable_to_generate_token(self):
+        """Check right error is raised when server can't generate unique token."""
+        # Create all possible single character tokens
+        tokens = []
+        for c in string.ascii_letters + string.digits + "-_":
+            tokens.append(
+                {
+                    "token": c,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                }
+            )
+        self.get_success(
+            self.store.db_pool.simple_insert_many(
+                "registration_tokens",
+                tokens,
+                "create_all_registration_tokens",
+            )
+        )
+
+        # Check creating a single character token fails with a 500 status code
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 1},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+
+    def test_create_uses_allowed(self):
+        """Check you can only create a token with good values for uses_allowed."""
+        # Should work with 0 (token is invalid from the start)
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"uses_allowed": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 0)
+
+        # Should fail with negative integer
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"uses_allowed": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with float
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"uses_allowed": 1.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_expiry_time(self):
+        """Check you can't create a token with an invalid expiry_time."""
+        # Should fail with a time in the past
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"expiry_time": self.clock.time_msec() - 10000},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with float
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"expiry_time": self.clock.time_msec() + 1000000.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_length(self):
+        """Check you can only generate a token with a valid length."""
+        # Should work with 64
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 64},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["token"]), 64)
+
+        # Should fail with 0
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a negative integer
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a float
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 8.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with 65
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 65},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    # UPDATING
+
+    def test_update_no_auth(self):
+        """Try to update a token without authentication."""
+        channel = self.make_request(
+            "PUT",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+        )
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_update_requester_not_admin(self):
+        """Try to update a token while not an admin."""
+        channel = self.make_request(
+            "PUT",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_update_non_existent(self):
+        """Try to update a token that doesn't exist."""
+        channel = self.make_request(
+            "PUT",
+            self.url + "/1234",
+            {"uses_allowed": 1},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_update_uses_allowed(self):
+        """Test updating just uses_allowed."""
+        # Create new token using default values
+        token = self._new_token()
+
+        # Should succeed with 1
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 1},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should succeed with 0 (makes token invalid)
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 0)
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should succeed with null
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": None},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should fail with a float
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 1.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a negative integer
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_update_expiry_time(self):
+        """Test updating just expiry_time."""
+        # Create new token using default values
+        token = self._new_token()
+        new_expiry_time = self.clock.time_msec() + 1000000
+
+        # Should succeed with a time in the future
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": new_expiry_time},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+
+        # Should succeed with null
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": None},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertIsNone(channel.json_body["uses_allowed"])
+
+        # Should fail with a time in the past
+        past_time = self.clock.time_msec() - 10000
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": past_time},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail a float
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": new_expiry_time + 0.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_update_both(self):
+        """Test updating both uses_allowed and expiry_time."""
+        # Create new token using default values
+        token = self._new_token()
+        new_expiry_time = self.clock.time_msec() + 1000000
+
+        data = {
+            "uses_allowed": 1,
+            "expiry_time": new_expiry_time,
+        }
+
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+
+    def test_update_invalid_type(self):
+        """Test using invalid types doesn't work."""
+        # Create new token using default values
+        token = self._new_token()
+
+        data = {
+            "uses_allowed": False,
+            "expiry_time": "1626430124000",
+        }
+
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    # DELETING
+
+    def test_delete_no_auth(self):
+        """Try to delete a token without authentication."""
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+        )
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_delete_requester_not_admin(self):
+        """Try to delete a token while not an admin."""
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_delete_non_existent(self):
+        """Try to delete a token that doesn't exist."""
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/1234",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_delete(self):
+        """Test deleting a token."""
+        # Create new token using default values
+        token = self._new_token()
+
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/" + token,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+    # GETTING ONE
+
+    def test_get_no_auth(self):
+        """Try to get a token without authentication."""
+        channel = self.make_request(
+            "GET",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+        )
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_get_requester_not_admin(self):
+        """Try to get a token while not an admin."""
+        channel = self.make_request(
+            "GET",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_get_non_existent(self):
+        """Try to get a token that doesn't exist."""
+        channel = self.make_request(
+            "GET",
+            self.url + "/1234",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_get(self):
+        """Test getting a token."""
+        # Create new token using default values
+        token = self._new_token()
+
+        channel = self.make_request(
+            "GET",
+            self.url + "/" + token,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["token"], token)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    # LISTING
+
+    def test_list_no_auth(self):
+        """Try to list tokens without authentication."""
+        channel = self.make_request("GET", self.url, {})
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_list_requester_not_admin(self):
+        """Try to list tokens while not an admin."""
+        channel = self.make_request(
+            "GET",
+            self.url,
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_list_all(self):
+        """Test listing all tokens."""
+        # Create new token using default values
+        token = self._new_token()
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
+        token_info = channel.json_body["registration_tokens"][0]
+        self.assertEqual(token_info["token"], token)
+        self.assertIsNone(token_info["uses_allowed"])
+        self.assertIsNone(token_info["expiry_time"])
+        self.assertEqual(token_info["pending"], 0)
+        self.assertEqual(token_info["completed"], 0)
+
+    def test_list_invalid_query_parameter(self):
+        """Test with `valid` query parameter not `true` or `false`."""
+        channel = self.make_request(
+            "GET",
+            self.url + "?valid=x",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+    def _test_list_query_parameter(self, valid: str):
+        """Helper used to test both valid=true and valid=false."""
+        # Create 2 valid and 2 invalid tokens.
+        now = self.hs.get_clock().time_msec()
+        # Create always valid token
+        valid1 = self._new_token()
+        # Create token that hasn't been used up
+        valid2 = self._new_token(uses_allowed=1)
+        # Create token that has expired
+        invalid1 = self._new_token(expiry_time=now - 10000)
+        # Create token that has been used up but hasn't expired
+        invalid2 = self._new_token(
+            uses_allowed=2,
+            pending=1,
+            completed=1,
+            expiry_time=now + 1000000,
+        )
+
+        if valid == "true":
+            tokens = [valid1, valid2]
+        else:
+            tokens = [invalid1, invalid2]
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?valid=" + valid,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
+        token_info_1 = channel.json_body["registration_tokens"][0]
+        token_info_2 = channel.json_body["registration_tokens"][1]
+        self.assertIn(token_info_1["token"], tokens)
+        self.assertIn(token_info_2["token"], tokens)
+
+    def test_list_valid(self):
+        """Test listing just valid tokens."""
+        self._test_list_query_parameter(valid="true")
+
+    def test_list_invalid(self):
+        """Test listing just invalid tokens."""
+        self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index c9d4731017..40e032df7f 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -29,123 +29,6 @@ from tests import unittest
 """Tests admin REST events for /rooms paths."""
 
 
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-        login.register_servlets,
-        events.register_servlets,
-        room.register_servlets,
-        room.register_deprecated_servlets,
-    ]
-
-    def prepare(self, reactor, clock, hs):
-        self.event_creation_handler = hs.get_event_creation_handler()
-        hs.config.user_consent_version = "1"
-
-        consent_uri_builder = Mock()
-        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
-        self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
-        self.store = hs.get_datastore()
-
-        self.admin_user = self.register_user("admin", "pass", admin=True)
-        self.admin_user_tok = self.login("admin", "pass")
-
-        self.other_user = self.register_user("user", "pass")
-        self.other_user_token = self.login("user", "pass")
-
-        # Mark the admin user as having consented
-        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
-    def test_shutdown_room_consent(self):
-        """Test that we can shutdown rooms with local users who have not
-        yet accepted the privacy policy. This used to fail when we tried to
-        force part the user from the old room.
-        """
-        self.event_creation_handler._block_events_without_consent_error = None
-
-        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
-        # Assert one user in room
-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
-        self.assertEqual([self.other_user], users_in_room)
-
-        # Enable require consent to send events
-        self.event_creation_handler._block_events_without_consent_error = "Error"
-
-        # Assert that the user is getting consent error
-        self.helper.send(
-            room_id, body="foo", tok=self.other_user_token, expect_code=403
-        )
-
-        # Test that the admin can still send shutdown
-        url = "/_synapse/admin/v1/shutdown_room/" + room_id
-        channel = self.make_request(
-            "POST",
-            url.encode("ascii"),
-            json.dumps({"new_room_user_id": self.admin_user}),
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
-        # Assert there is now no longer anyone in the room
-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
-        self.assertEqual([], users_in_room)
-
-    def test_shutdown_room_block_peek(self):
-        """Test that a world_readable room can no longer be peeked into after
-        it has been shut down.
-        """
-
-        self.event_creation_handler._block_events_without_consent_error = None
-
-        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
-        # Enable world readable
-        url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
-        channel = self.make_request(
-            "PUT",
-            url.encode("ascii"),
-            json.dumps({"history_visibility": "world_readable"}),
-            access_token=self.other_user_token,
-        )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
-        # Test that the admin can still send shutdown
-        url = "/_synapse/admin/v1/shutdown_room/" + room_id
-        channel = self.make_request(
-            "POST",
-            url.encode("ascii"),
-            json.dumps({"new_room_user_id": self.admin_user}),
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
-        # Assert we can no longer peek into the room
-        self._assert_peek(room_id, expect_code=403)
-
-    def _assert_peek(self, room_id, expect_code):
-        """Assert that the admin user can (or cannot) peek into the room."""
-
-        url = "rooms/%s/initialSync" % (room_id,)
-        channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok
-        )
-        self.assertEqual(
-            expect_code, int(channel.result["code"]), msg=channel.result["body"]
-        )
-
-        url = "events?timeout=0&room_id=" + room_id
-        channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok
-        )
-        self.assertEqual(
-            expect_code, int(channel.result["code"]), msg=channel.result["body"]
-        )
-
-
 @parameterized_class(
     ("method", "url_template"),
     [
@@ -557,51 +440,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         )
 
 
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
-    """Test /purge_room admin API."""
-
-    servlets = [
-        synapse.rest.admin.register_servlets,
-        login.register_servlets,
-        room.register_servlets,
-    ]
-
-    def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
-        self.admin_user = self.register_user("admin", "pass", admin=True)
-        self.admin_user_tok = self.login("admin", "pass")
-
-    def test_purge_room(self):
-        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
-        # All users have to have left the room.
-        self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
-        url = "/_synapse/admin/v1/purge_room"
-        channel = self.make_request(
-            "POST",
-            url.encode("ascii"),
-            {"room_id": room_id},
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
-        # Test that the following tables have been purged of all rows related to the room.
-        for table in PURGE_TABLES:
-            count = self.get_success(
-                self.store.db_pool.simple_select_one_onecol(
-                    table=table,
-                    keyvalues={"room_id": room_id},
-                    retcol="COUNT(*)",
-                    desc="test_purge_room",
-                )
-            )
-
-            self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
-
-
 class RoomTestCase(unittest.HomeserverTestCase):
     """Test /room admin API."""
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ef77275238..ee204c404b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual(1, len(channel.json_body["threepids"]))
         self.assertEqual(
             "external_id1", channel.json_body["external_ids"][0]["external_id"]
         )
         self.assertEqual(
             "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
         )
+        self.assertEqual(1, len(channel.json_body["external_ids"]))
         self.assertFalse(channel.json_body["admin"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
         self._check_fields(channel.json_body)
@@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         Test setting threepid for an other user.
         """
 
-        # Delete old and add new threepid to user
+        # Add two threepids to user
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
+            content={
+                "threepids": [
+                    {"medium": "email", "address": "bob1@bob.bob"},
+                    {"medium": "email", "address": "bob2@bob.bob"},
+                ],
+            },
         )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(2, len(channel.json_body["threepids"]))
+        # result does not always have the same sort order, therefore it becomes sorted
+        sorted_result = sorted(
+            channel.json_body["threepids"], key=lambda k: k["address"]
+        )
+        self.assertEqual("email", sorted_result[0]["medium"])
+        self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
+        self.assertEqual("email", sorted_result[1]["medium"])
+        self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
+        self._check_fields(channel.json_body)
+
+        # Set a new and remove a threepid
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={
+                "threepids": [
+                    {"medium": "email", "address": "bob2@bob.bob"},
+                    {"medium": "email", "address": "bob3@bob.bob"},
+                ],
+            },
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(2, len(channel.json_body["threepids"]))
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+        self._check_fields(channel.json_body)
 
         # Get user
         channel = self.make_request(
@@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(2, len(channel.json_body["threepids"]))
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+        self._check_fields(channel.json_body)
+
+        # Remove threepids
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"threepids": []},
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self._check_fields(channel.json_body)
 
     def test_set_external_id(self):
         """
@@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(2, len(channel.json_body["external_ids"]))
         self.assertEqual(
             channel.json_body["external_ids"],
             [
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/test_account.py
index b946fca8b3..b946fca8b3 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/test_account.py
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/test_auth.py
index cf5cfb910c..e2fcbdc63a 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -25,7 +25,7 @@ from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 from tests.unittest import override_config, skip_unless
 
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/test_capabilities.py
index 13b3c5f499..422361b62a 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -30,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.url = b"/_matrix/client/r0/capabilities"
         hs = self.setup_test_homeserver()
-        self.store = hs.get_datastore()
         self.config = hs.config
         self.auth_handler = hs.get_auth_handler()
         return hs
 
+    def prepare(self, reactor, clock, hs):
+        self.localpart = "user"
+        self.password = "pass"
+        self.user = self.register_user(self.localpart, self.password)
+
     def test_check_auth_required(self):
         channel = self.make_request("GET", self.url)
 
         self.assertEqual(channel.code, 401)
 
     def test_get_room_version_capabilities(self):
-        self.register_user("user", "pass")
-        access_token = self.login("user", "pass")
+        access_token = self.login(self.localpart, self.password)
 
         channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
@@ -57,10 +60,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_change_password_capabilities_password_login(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
-        access_token = self.login(user, password)
+        access_token = self.login(self.localpart, self.password)
 
         channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
@@ -70,12 +70,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
     @override_config({"password_config": {"localdb_enabled": False}})
     def test_get_change_password_capabilities_localdb_disabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
@@ -87,12 +84,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
     @override_config({"password_config": {"enabled": False}})
     def test_get_change_password_capabilities_password_disabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
@@ -102,14 +96,86 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertFalse(capabilities["m.change_password"]["enabled"])
 
+    def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
+        """Test that per default msc3283 is disabled server returns `m.change_password`."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
+        self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
+        self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
+        self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
+
+    @override_config({"experimental_features": {"msc3283_enabled": True}})
+    def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
+        """Test if msc3283 is enabled server returns capabilities."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+    @override_config(
+        {
+            "enable_set_displayname": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_get_set_displayname_capabilities_displayname_disabled(self):
+        """Test if set displayname is disabled that the server responds it."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+
+    @override_config(
+        {
+            "enable_set_avatar_url": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
+        """Test if set avatar_url is disabled that the server responds it."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+
+    @override_config(
+        {
+            "enable_3pid_changes": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_change_3pid_capabilities_3pid_disabled(self):
+        """Test if change 3pid is disabled that the server responds it."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
     @override_config({"experimental_features": {"msc3244_enabled": False}})
     def test_get_does_not_include_msc3244_fields_when_disabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
@@ -122,12 +188,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_does_include_msc3244_fields_when_enabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/test_directory.py
index d2181ea907..d2181ea907 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/test_directory.py
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/test_events.py
index a90294003e..a90294003e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/test_events.py
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/test_filter.py
index 475c6bed3d..475c6bed3d 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/test_filter.py
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
new file mode 100644
index 0000000000..d7fa635eae
--- /dev/null
+++ b/tests/rest/client/test_keys.py
@@ -0,0 +1,91 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License
+
+from http import HTTPStatus
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import keys, login
+
+from tests import unittest
+
+
+class KeyQueryTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        keys.register_servlets,
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+    ]
+
+    def test_rejects_device_id_ice_key_outside_of_list(self):
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        bob = self.register_user("bob", "uncle")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    bob: "device_id1",
+                },
+            },
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
+
+    def test_rejects_device_key_given_as_map_to_bool(self):
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        bob = self.register_user("bob", "uncle")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    bob: {
+                        "device_id1": True,
+                    },
+                },
+            },
+            alice_token,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
+
+    def test_requires_device_key(self):
+        """`device_keys` is required. We should complain if it's missing."""
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {},
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/test_login.py
index eba3552b19..5b2243fe52 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -32,7 +32,7 @@ from synapse.types import create_requester
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
-from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
 from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3cf5871899..3cf5871899 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/test_presence.py
index 1d152352d1..1d152352d1 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/test_presence.py
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/test_profile.py
index 2860579c2e..2860579c2e 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/test_profile.py
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
index d0ce91ccd9..d0ce91ccd9 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/test_push_rule_attrs.py
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/test_register.py
index fecda037a5..9f3ab2c985 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
 
 from tests import unittest
 from tests.unittest import override_config
@@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_requires_token(self):
+        username = "kermit"
+        device_id = "frogfone"
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params = {
+            "username": username,
+            "password": "monkey",
+            "device_id": device_id,
+        }
+
+        # Request without auth to get flows and session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+        # Synapse adds a dummy stage to differentiate flows where otherwise one
+        # flow would be a subset of another flow.
+        self.assertCountEqual(
+            [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+            (f["stages"] for f in flows),
+        )
+        session = channel.json_body["session"]
+
+        # Do the registration token stage and check it has completed
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+        # Do the m.login.dummy stage and check registration was successful
+        params["auth"] = {
+            "type": LoginType.DUMMY,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        det_data = {
+            "user_id": f"@{username}:{self.hs.hostname}",
+            "home_server": self.hs.hostname,
+            "device_id": device_id,
+        }
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertDictContainsSubset(det_data, channel.json_body)
+
+        # Check the `completed` counter has been incremented and pending is 0
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["completed"], 1)
+        self.assertEquals(res["pending"], 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_invalid(self):
+        params = {
+            "username": "kermit",
+            "password": "monkey",
+        }
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Test with token param missing (invalid)
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with non-string (invalid)
+        params["auth"]["token"] = 1234
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with unknown token (invalid)
+        params["auth"]["token"] = "1234"
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_limit_uses(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        # Create token that can be used once
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": 1,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        # Do 2 requests without auth to get two session IDs
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with session1 and check `pending` is 1
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Repeat request to make sure pending isn't increased again
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 1)
+
+        # Check auth fails when using token with session2
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Check pending=0 and completed=1
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["pending"], 0)
+        self.assertEquals(res["completed"], 1)
+
+        # Check auth still fails when using token with session2
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_expiry(self):
+        token = "abcd"
+        now = self.hs.get_clock().time_msec()
+        store = self.hs.get_datastore()
+        # Create token that expired yesterday
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": now - 24 * 60 * 60 * 1000,
+                },
+            )
+        )
+        params = {"username": "kermit", "password": "monkey"}
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Check authentication fails with expired token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Update token so it expires tomorrow
+        self.get_success(
+            store.db_pool.simple_update_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+            )
+        )
+
+        # Check authentication succeeds
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry(self):
+        """Test `pending` is decremented when an uncompleted session expires."""
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do 2 requests without auth to get two session IDs
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with both sessions
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params2))
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        # Check `result` of registration token stage for session1 is `True`
+        result1 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session1,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertTrue(db_to_json(result1))
+
+        # Check `result` for session2 is the token used
+        result2 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session2,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertEquals(db_to_json(result2), token)
+
+        # Delete both sessions (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
+        # Check pending is now 0
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry_deleted_token(self):
+        """Test session expiry doesn't break when the token is deleted.
+
+        1. Start but don't complete UIA with a registration token
+        2. Delete the token from the database
+        3. Expire the session
+        """
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do request without auth to get a session ID
+        params = {"username": "kermit", "password": "monkey"}
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Use token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params))
+
+        # Delete token
+        self.get_success(
+            store.db_pool.simple_delete_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+            )
+        )
+
+        # Delete session (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
     def test_advertised_flows(self):
         channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
         self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
         self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+    servlets = [register.register_servlets]
+    url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+    def default_config(self):
+        config = super().default_config()
+        config["registration_requires_token"] = True
+        return config
+
+    def test_GET_token_valid(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], True)
+
+    def test_GET_token_invalid(self):
+        token = "1234"
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], False)
+
+    @override_config(
+        {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+    )
+    def test_GET_ratelimiting(self):
+        token = "1234"
+
+        for i in range(0, 6):
+            channel = self.make_request(
+                b"GET",
+                f"{self.url}?token={token}",
+            )
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/test_relations.py
index 02b5e9a8d0..02b5e9a8d0 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/test_relations.py
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/test_report_event.py
index ee6b0b9ebf..ee6b0b9ebf 100644
--- a/tests/rest/client/v2_alpha/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/test_rooms.py
index 0c9cbb9aff..0c9cbb9aff 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 6db7062a8e..6db7062a8e 100644
--- a/tests/rest/client/v2_alpha/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py
index 283eccd53f..283eccd53f 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/test_shared_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/test_sync.py
index 95be369d4b..95be369d4b 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/test_sync.py
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/test_typing.py
index b54b004733..b54b004733 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/test_typing.py
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 72f976d8e2..72f976d8e2 100644
--- a/tests/rest/client/v2_alpha/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/utils.py
index 954ad1a1fd..954ad1a1fd 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/utils.py
diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py
deleted file mode 100644
index 5e83dba2ed..0000000000
--- a/tests/rest/client/v1/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tests/rest/client/v2_alpha/__init__.py
+++ /dev/null
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 3785799f46..348fcb72a7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Send the join, it should return None (which is not an error)
         self.assertEqual(
-            self.get_success(
-                self.handler.on_receive_pdu(
-                    "test.serv", join_event, sent_to_us_directly=True
-                )
-            ),
+            self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
             None,
         )
 
@@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         with LoggingContext("test-context"):
             failure = self.get_failure(
-                self.handler.on_receive_pdu(
-                    "test.serv", lying_event, sent_to_us_directly=True
-                ),
+                self.handler.on_receive_pdu("test.serv", lying_event),
                 FederationError,
             )
 
diff --git a/tests/unittest.py b/tests/unittest.py
index 3eec9c4d5b..f2c90cc47b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -252,7 +252,7 @@ class HomeserverTestCase(TestCase):
             reactor=self.reactor,
         )
 
-        from tests.rest.client.v1.utils import RestHelper
+        from tests.rest.client.utils import RestHelper
 
         self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))