summary refs log tree commit diff
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2021-02-16 13:33:20 +0000
committerBen Banfield-Zanin <benbz@matrix.org>2021-02-16 13:33:20 +0000
commitdcf1b9c276e22bb6f5200fc029301c4d40e87a1f (patch)
tree1f5badce24645d99534133a7a989069906088fff
parentMerge remote-tracking branch 'origin/release-v1.24.0' into bbz/info-mainline-... (diff)
parentFixup CHANGES (diff)
downloadsynapse-dcf1b9c276e22bb6f5200fc029301c4d40e87a1f.tar.xz
Merge remote-tracking branch 'origin/release-v1.27.0' into bbz/info-mainline-1.27.0 github/bbz/info-mainline-1.27.0 bbz/info-mainline-1.27.0
-rwxr-xr-x.buildkite/scripts/create_postgres_db.py1
-rwxr-xr-x.buildkite/scripts/test_old_deps.sh3
-rw-r--r--.circleci/config.yml15
-rw-r--r--.gitignore2
-rw-r--r--CHANGES.md408
-rw-r--r--INSTALL.md273
-rw-r--r--README.rst23
-rw-r--r--UPGRADE.rst147
-rw-r--r--contrib/prometheus/synapse-v2.rules18
-rwxr-xr-xdebian/build_virtualenv4
-rw-r--r--debian/changelog37
-rw-r--r--debian/control7
-rw-r--r--demo/webserver.py59
-rw-r--r--docker/Dockerfile2
-rw-r--r--docker/Dockerfile-dhvirtualenv13
-rw-r--r--docker/conf/homeserver.yaml10
-rw-r--r--docs/admin_api/media_admin_api.md140
-rw-r--r--docs/admin_api/purge_remote_media.rst20
-rw-r--r--docs/admin_api/purge_room.md9
-rw-r--r--docs/admin_api/rooms.md165
-rw-r--r--docs/admin_api/shutdown_room.md7
-rw-r--r--docs/admin_api/user_admin_api.rst64
-rw-r--r--docs/auth_chain_diff.dot32
-rw-r--r--docs/auth_chain_diff.dot.pngbin0 -> 42427 bytes
-rw-r--r--docs/auth_chain_difference_algorithm.md108
-rw-r--r--docs/dev/cas.md6
-rw-r--r--docs/openid.md300
-rw-r--r--docs/postgres.md2
-rw-r--r--docs/sample_config.yaml567
-rw-r--r--docs/spam_checker.md19
-rw-r--r--docs/sso_mapping_providers.md32
-rw-r--r--docs/systemd-with-workers/README.md2
-rw-r--r--docs/turn-howto.md6
-rw-r--r--docs/workers.md39
-rw-r--r--mypy.ini61
-rwxr-xr-xscripts-dev/lint.sh3
-rw-r--r--scripts-dev/mypy_synapse_plugin.py2
-rwxr-xr-xscripts/generate_log_config4
-rwxr-xr-xscripts/synapse_port_db29
-rwxr-xr-xsetup.py3
-rw-r--r--stubs/frozendict.pyi11
-rw-r--r--stubs/sortedcontainers/sorteddict.pyi6
-rw-r--r--stubs/txredisapi.pyi25
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py16
-rw-r--r--synapse/api/auth_blocking.py7
-rw-r--r--synapse/api/constants.py9
-rw-r--r--synapse/api/room_versions.py32
-rw-r--r--synapse/app/_base.py149
-rw-r--r--synapse/app/generic_worker.py78
-rw-r--r--synapse/app/homeserver.py117
-rw-r--r--synapse/app/phone_stats_home.py9
-rw-r--r--synapse/config/_base.py102
-rw-r--r--synapse/config/_base.pyi17
-rw-r--r--synapse/config/_util.py35
-rw-r--r--synapse/config/auth.py (renamed from synapse/config/password.py)26
-rw-r--r--synapse/config/captcha.py4
-rw-r--r--synapse/config/cas.py18
-rw-r--r--synapse/config/consent_config.py2
-rw-r--r--synapse/config/emailconfig.py27
-rw-r--r--synapse/config/experimental.py29
-rw-r--r--synapse/config/federation.py43
-rw-r--r--synapse/config/groups.py2
-rw-r--r--synapse/config/homeserver.py6
-rw-r--r--synapse/config/logger.py2
-rw-r--r--synapse/config/oidc_config.py613
-rw-r--r--synapse/config/password_auth_providers.py5
-rw-r--r--synapse/config/ratelimiting.py32
-rw-r--r--synapse/config/registration.py15
-rw-r--r--synapse/config/repository.py26
-rw-r--r--synapse/config/room_directory.py2
-rw-r--r--synapse/config/saml2_config.py10
-rw-r--r--synapse/config/server.py105
-rw-r--r--synapse/config/spam_checker.py9
-rw-r--r--synapse/config/sso.py138
-rw-r--r--synapse/config/third_party_event_rules.py4
-rw-r--r--synapse/config/workers.py36
-rw-r--r--synapse/crypto/context_factory.py15
-rw-r--r--synapse/crypto/event_signing.py29
-rw-r--r--synapse/crypto/keyring.py210
-rw-r--r--synapse/events/__init__.py3
-rw-r--r--synapse/events/spamcheck.py55
-rw-r--r--synapse/events/utils.py16
-rw-r--r--synapse/federation/federation_base.py7
-rw-r--r--synapse/federation/federation_client.py127
-rw-r--r--synapse/federation/federation_server.py24
-rw-r--r--synapse/federation/sender/__init__.py50
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/federation/transport/server.py6
-rw-r--r--synapse/handlers/_base.py4
-rw-r--r--synapse/handlers/account_data.py144
-rw-r--r--synapse/handlers/acme.py12
-rw-r--r--synapse/handlers/acme_issuing_service.py27
-rw-r--r--synapse/handlers/admin.py63
-rw-r--r--synapse/handlers/auth.py288
-rw-r--r--synapse/handlers/cas_handler.py356
-rw-r--r--synapse/handlers/deactivate_account.py18
-rw-r--r--synapse/handlers/device.py12
-rw-r--r--synapse/handlers/devicemessage.py48
-rw-r--r--synapse/handlers/directory.py6
-rw-r--r--synapse/handlers/e2e_keys.py223
-rw-r--r--synapse/handlers/e2e_room_keys.py91
-rw-r--r--synapse/handlers/federation.py15
-rw-r--r--synapse/handlers/groups_local.py85
-rw-r--r--synapse/handlers/identity.py41
-rw-r--r--synapse/handlers/initial_sync.py4
-rw-r--r--synapse/handlers/message.py48
-rw-r--r--synapse/handlers/oidc_handler.py745
-rw-r--r--synapse/handlers/profile.py12
-rw-r--r--synapse/handlers/read_marker.py5
-rw-r--r--synapse/handlers/receipts.py64
-rw-r--r--synapse/handlers/register.py16
-rw-r--r--synapse/handlers/room.py27
-rw-r--r--synapse/handlers/room_list.py101
-rw-r--r--synapse/handlers/room_member.py63
-rw-r--r--synapse/handlers/saml_handler.py166
-rw-r--r--synapse/handlers/search.py38
-rw-r--r--synapse/handlers/set_password.py10
-rw-r--r--synapse/handlers/sso.py736
-rw-r--r--synapse/handlers/state_deltas.py14
-rw-r--r--synapse/handlers/stats.py39
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/handlers/typing.py69
-rw-r--r--synapse/handlers/ui_auth/__init__.py15
-rw-r--r--synapse/handlers/user_directory.py94
-rw-r--r--synapse/http/__init__.py15
-rw-r--r--synapse/http/client.py101
-rw-r--r--synapse/http/endpoint.py79
-rw-r--r--synapse/http/federation/matrix_federation_agent.py17
-rw-r--r--synapse/http/federation/well_known_resolver.py25
-rw-r--r--synapse/http/matrixfederationclient.py49
-rw-r--r--synapse/http/proxyagent.py16
-rw-r--r--synapse/http/server.py59
-rw-r--r--synapse/http/site.py21
-rw-r--r--synapse/logging/context.py70
-rw-r--r--synapse/logging/opentracing.py2
-rw-r--r--synapse/metrics/background_process_metrics.py16
-rw-r--r--synapse/module_api/__init__.py10
-rw-r--r--synapse/notifier.py45
-rw-r--r--synapse/push/__init__.py108
-rw-r--r--synapse/push/action_generator.py15
-rw-r--r--synapse/push/baserules.py23
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py110
-rw-r--r--synapse/push/clientformat.py23
-rw-r--r--synapse/push/emailpusher.py114
-rw-r--r--synapse/push/httppusher.py127
-rw-r--r--synapse/push/mailer.py434
-rw-r--r--synapse/push/presentable_names.py74
-rw-r--r--synapse/push/push_rule_evaluator.py28
-rw-r--r--synapse/push/push_tools.py7
-rw-r--r--synapse/push/pusher.py34
-rw-r--r--synapse/push/pusherpool.py172
-rw-r--r--synapse/python_dependencies.py4
-rw-r--r--synapse/replication/http/__init__.py2
-rw-r--r--synapse/replication/http/_base.py49
-rw-r--r--synapse/replication/http/account_data.py187
-rw-r--r--synapse/replication/http/login.py12
-rw-r--r--synapse/replication/slave/storage/_base.py10
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py20
-rw-r--r--synapse/replication/slave/storage/account_data.py40
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py40
-rw-r--r--synapse/replication/slave/storage/pushers.py17
-rw-r--r--synapse/replication/slave/storage/receipts.py35
-rw-r--r--synapse/replication/tcp/external_cache.py105
-rw-r--r--synapse/replication/tcp/handler.py49
-rw-r--r--synapse/replication/tcp/protocol.py3
-rw-r--r--synapse/replication/tcp/redis.py143
-rw-r--r--synapse/res/templates/notif.html2
-rw-r--r--synapse/res/templates/sso.css88
-rw-r--r--synapse/res/templates/sso_account_deactivated.html26
-rw-r--r--synapse/res/templates/sso_auth_account_details.html161
-rw-r--r--synapse/res/templates/sso_auth_account_details.js116
-rw-r--r--synapse/res/templates/sso_auth_bad_user.html25
-rw-r--r--synapse/res/templates/sso_auth_confirm.html32
-rw-r--r--synapse/res/templates/sso_auth_success.html39
-rw-r--r--synapse/res/templates/sso_error.html101
-rw-r--r--synapse/res/templates/sso_login_idp_picker.html31
-rw-r--r--synapse/res/templates/sso_new_user_consent.html39
-rw-r--r--synapse/res/templates/sso_redirect_confirm.html34
-rw-r--r--synapse/rest/admin/__init__.py10
-rw-r--r--synapse/rest/admin/media.py64
-rw-r--r--synapse/rest/admin/rooms.py292
-rw-r--r--synapse/rest/admin/users.py109
-rw-r--r--synapse/rest/client/v1/login.py169
-rw-r--r--synapse/rest/client/v1/pusher.py15
-rw-r--r--synapse/rest/client/v1/room.py37
-rw-r--r--synapse/rest/client/v2_alpha/account.py66
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py22
-rw-r--r--synapse/rest/client/v2_alpha/auth.py53
-rw-r--r--synapse/rest/client/v2_alpha/devices.py12
-rw-r--r--synapse/rest/client/v2_alpha/groups.py48
-rw-r--r--synapse/rest/client/v2_alpha/keys.py6
-rw-r--r--synapse/rest/client/v2_alpha/register.py43
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py3
-rw-r--r--synapse/rest/client/v2_alpha/tags.py11
-rw-r--r--synapse/rest/consent/consent_resource.py1
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py9
-rw-r--r--synapse/rest/media/v1/_base.py84
-rw-r--r--synapse/rest/media/v1/config_resource.py14
-rw-r--r--synapse/rest/media/v1/download_resource.py18
-rw-r--r--synapse/rest/media/v1/filepath.py50
-rw-r--r--synapse/rest/media/v1/media_repository.py52
-rw-r--r--synapse/rest/media/v1/media_storage.py12
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py123
-rw-r--r--synapse/rest/media/v1/storage_provider.py51
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py313
-rw-r--r--synapse/rest/media/v1/thumbnailer.py18
-rw-r--r--synapse/rest/media/v1/upload_resource.py16
-rw-r--r--synapse/rest/synapse/client/__init__.py54
-rw-r--r--synapse/rest/synapse/client/new_user_consent.py97
-rw-r--r--synapse/rest/synapse/client/oidc/__init__.py (renamed from synapse/rest/oidc/__init__.py)6
-rw-r--r--synapse/rest/synapse/client/oidc/callback_resource.py (renamed from synapse/rest/oidc/callback_resource.py)0
-rw-r--r--synapse/rest/synapse/client/pick_idp.py84
-rw-r--r--synapse/rest/synapse/client/pick_username.py131
-rw-r--r--synapse/rest/synapse/client/saml2/__init__.py (renamed from synapse/rest/saml2/__init__.py)8
-rw-r--r--synapse/rest/synapse/client/saml2/metadata_resource.py (renamed from synapse/rest/saml2/metadata_resource.py)0
-rw-r--r--synapse/rest/synapse/client/saml2/response_resource.py (renamed from synapse/rest/saml2/response_resource.py)0
-rw-r--r--synapse/rest/synapse/client/sso_register.py50
-rw-r--r--synapse/server.py80
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py3
-rw-r--r--synapse/server_notices/server_notices_manager.py3
-rw-r--r--synapse/state/__init__.py15
-rw-r--r--synapse/state/v2.py92
-rw-r--r--synapse/static/client/login/style.css5
-rw-r--r--synapse/storage/__init__.py9
-rw-r--r--synapse/storage/_base.py36
-rw-r--r--synapse/storage/background_updates.py111
-rw-r--r--synapse/storage/database.py52
-rw-r--r--synapse/storage/databases/main/__init__.py57
-rw-r--r--synapse/storage/databases/main/account_data.py254
-rw-r--r--synapse/storage/databases/main/client_ips.py102
-rw-r--r--synapse/storage/databases/main/deviceinbox.py250
-rw-r--r--synapse/storage/databases/main/devices.py36
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py133
-rw-r--r--synapse/storage/databases/main/event_federation.py189
-rw-r--r--synapse/storage/databases/main/event_push_actions.py106
-rw-r--r--synapse/storage/databases/main/events.py703
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py360
-rw-r--r--synapse/storage/databases/main/events_forward_extremities.py101
-rw-r--r--synapse/storage/databases/main/events_worker.py8
-rw-r--r--synapse/storage/databases/main/keys.py10
-rw-r--r--synapse/storage/databases/main/media_repository.py13
-rw-r--r--synapse/storage/databases/main/metrics.py56
-rw-r--r--synapse/storage/databases/main/profile.py2
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/pusher.py102
-rw-r--r--synapse/storage/databases/main/receipts.py108
-rw-r--r--synapse/storage/databases/main/registration.py134
-rw-r--r--synapse/storage/databases/main/room.py61
-rw-r--r--synapse/storage/databases/main/roommember.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/58/27local_invites.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres16
-rw-r--r--synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite62
-rw-r--r--synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/01ignored_user.py82
-rw-r--r--synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres25
-rw-r--r--synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql52
-rw-r--r--synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres16
-rw-r--r--synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql20
-rw-r--r--synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres32
-rw-r--r--synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql18
-rw-r--r--synapse/storage/databases/main/search.py7
-rw-r--r--synapse/storage/databases/main/stats.py22
-rw-r--r--synapse/storage/databases/main/tags.py20
-rw-r--r--synapse/storage/databases/main/transactions.py24
-rw-r--r--synapse/storage/databases/main/user_directory.py33
-rw-r--r--synapse/storage/databases/state/store.py4
-rw-r--r--synapse/storage/keys.py5
-rw-r--r--synapse/storage/persist_events.py200
-rw-r--r--synapse/storage/prepare_database.py126
-rw-r--r--synapse/storage/purge_events.py11
-rw-r--r--synapse/storage/relations.py44
-rw-r--r--synapse/storage/state.py35
-rw-r--r--synapse/storage/util/id_generators.py113
-rw-r--r--synapse/storage/util/sequence.py82
-rw-r--r--synapse/types.py16
-rw-r--r--synapse/util/async_helpers.py8
-rw-r--r--synapse/util/caches/deferred_cache.py2
-rw-r--r--synapse/util/distributor.py7
-rw-r--r--synapse/util/iterutils.py53
-rw-r--r--synapse/util/metrics.py11
-rw-r--r--synapse/util/module_loader.py65
-rw-r--r--synapse/util/stringutils.py111
-rw-r--r--synapse/util/templates.py115
-rw-r--r--synapse/visibility.py44
-rw-r--r--tests/api/test_filtering.py16
-rw-r--r--tests/app/test_frontend_proxy.py6
-rw-r--r--tests/app/test_openid_listener.py8
-rw-r--r--tests/config/test_util.py53
-rw-r--r--tests/crypto/test_keyring.py16
-rw-r--r--tests/events/test_utils.py185
-rw-r--r--tests/federation/test_complexity.py4
-rw-r--r--tests/federation/test_federation_server.py6
-rw-r--r--tests/federation/transport/__init__.py0
-rw-r--r--tests/federation/transport/test_server.py39
-rw-r--r--tests/handlers/test_cas.py121
-rw-r--r--tests/handlers/test_device.py4
-rw-r--r--tests/handlers/test_directory.py14
-rw-r--r--tests/handlers/test_federation.py54
-rw-r--r--tests/handlers/test_message.py2
-rw-r--r--tests/handlers/test_oidc.py519
-rw-r--r--tests/handlers/test_password_providers.py29
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/handlers/test_profile.py32
-rw-r--r--tests/handlers/test_saml.py155
-rw-r--r--tests/handlers/test_typing.py19
-rw-r--r--tests/handlers/test_user_directory.py48
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py32
-rw-r--r--tests/http/test_additional_resource.py8
-rw-r--r--tests/http/test_client.py101
-rw-r--r--tests/http/test_endpoint.py2
-rw-r--r--tests/http/test_fedclient.py2
-rw-r--r--tests/http/test_proxyagent.py130
-rw-r--r--tests/logging/test_terse_json.py70
-rw-r--r--tests/push/test_email.py36
-rw-r--r--tests/push/test_http.py118
-rw-r--r--tests/push/test_presentable_names.py229
-rw-r--r--tests/push/test_push_rule_evaluator.py2
-rw-r--r--tests/replication/_base.py59
-rw-r--r--tests/replication/test_auth.py117
-rw-r--r--tests/replication/test_client_reader_shard.py36
-rw-r--r--tests/replication/test_federation_sender_shard.py10
-rw-r--r--tests/replication/test_multi_media_repo.py4
-rw-r--r--tests/replication/test_pusher_shard.py14
-rw-r--r--tests/replication/test_sharded_event_persister.py16
-rw-r--r--tests/rest/admin/test_admin.py45
-rw-r--r--tests/rest/admin/test_device.py114
-rw-r--r--tests/rest/admin/test_event_reports.py66
-rw-r--r--tests/rest/admin/test_media.py64
-rw-r--r--tests/rest/admin/test_room.py289
-rw-r--r--tests/rest/admin/test_statistics.py61
-rw-r--r--tests/rest/admin/test_user.py992
-rw-r--r--tests/rest/client/test_consent.py8
-rw-r--r--tests/rest/client/test_ephemeral_message.py2
-rw-r--r--tests/rest/client/test_identity.py6
-rw-r--r--tests/rest/client/test_redactions.py8
-rw-r--r--tests/rest/client/test_retention.py2
-rw-r--r--tests/rest/client/test_shadow_banned.py24
-rw-r--r--tests/rest/client/test_third_party_rules.py10
-rw-r--r--tests/rest/client/v1/test_directory.py8
-rw-r--r--tests/rest/client/v1/test_events.py8
-rw-r--r--tests/rest/client/v1/test_login.py609
-rw-r--r--tests/rest/client/v1/test_presence.py6
-rw-r--r--tests/rest/client/v1/test_profile.py18
-rw-r--r--tests/rest/client/v1/test_push_rule_attrs.py66
-rw-r--r--tests/rest/client/v1/test_rooms.py283
-rw-r--r--tests/rest/client/v1/test_typing.py10
-rw-r--r--tests/rest/client/v1/utils.py268
-rw-r--r--tests/rest/client/v2_alpha/test_account.py148
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py260
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py8
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py14
-rw-r--r--tests/rest/client/v2_alpha/test_password_policy.py18
-rw-r--r--tests/rest/client/v2_alpha/test_register.py71
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py36
-rw-r--r--tests/rest/client/v2_alpha/test_shared_rooms.py16
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py32
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
-rw-r--r--tests/rest/media/v1/test_media_storage.py44
-rw-r--r--tests/rest/media/v1/test_url_preview.py69
-rw-r--r--tests/rest/test_health.py4
-rw-r--r--tests/rest/test_well_known.py8
-rw-r--r--tests/server.py55
-rw-r--r--tests/server_notices/test_consent.py6
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py8
-rw-r--r--tests/state/test_v2.py193
-rw-r--r--tests/storage/test_account_data.py120
-rw-r--r--tests/storage/test_devices.py26
-rw-r--r--tests/storage/test_e2e_room_keys.py2
-rw-r--r--tests/storage/test_event_chain.py741
-rw-r--r--tests/storage/test_event_federation.py270
-rw-r--r--tests/storage/test_events.py334
-rw-r--r--tests/storage/test_id_generators.py112
-rw-r--r--tests/storage/test_main.py7
-rw-r--r--tests/storage/test_profile.py26
-rw-r--r--tests/storage/test_purge.py2
-rw-r--r--tests/storage/test_redaction.py11
-rw-r--r--tests/storage/test_roommember.py8
-rw-r--r--tests/storage/test_user_directory.py23
-rw-r--r--tests/test_federation.py4
-rw-r--r--tests/test_mau.py48
-rw-r--r--tests/test_preview.py67
-rw-r--r--tests/test_server.py30
-rw-r--r--tests/test_terms_auth.py6
-rw-r--r--tests/test_types.py4
-rw-r--r--tests/test_utils/__init__.py39
-rw-r--r--tests/test_utils/html_parsers.py53
-rw-r--r--tests/test_utils/logging_setup.py2
-rw-r--r--tests/unittest.py134
-rw-r--r--tests/util/caches/test_deferred_cache.py73
-rw-r--r--tests/util/test_itertools.py61
-rw-r--r--tests/utils.py112
-rw-r--r--tox.ini90
399 files changed, 20666 insertions, 6848 deletions
diff --git a/.buildkite/scripts/create_postgres_db.py b/.buildkite/scripts/create_postgres_db.py
index df6082b0ac..956339de5c 100755
--- a/.buildkite/scripts/create_postgres_db.py
+++ b/.buildkite/scripts/create_postgres_db.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+
 from synapse.storage.engines import create_engine
 
 logger = logging.getLogger("create_postgres_db")
diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh
index 9905c4bc4f..28e6694b5d 100755
--- a/.buildkite/scripts/test_old_deps.sh
+++ b/.buildkite/scripts/test_old_deps.sh
@@ -10,4 +10,7 @@ apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev x
 
 export LANG="C.UTF-8"
 
+# Prevent virtualenv from auto-updating pip to an incompatible version
+export VIRTUALENV_NO_DOWNLOAD=1
+
 exec tox -e py35-old,combine
diff --git a/.circleci/config.yml b/.circleci/config.yml
index b10cbedd6d..375a7f7b04 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -5,9 +5,10 @@ jobs:
       - image: docker:git
     steps:
       - checkout
-      - setup_remote_docker
       - docker_prepare
       - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
+      # for release builds, we want to get the amd64 image out asap, so first
+      # we do an amd64-only build, before following up with a multiarch build.
       - docker_build:
           tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
           platforms: linux/amd64
@@ -20,12 +21,10 @@ jobs:
       - image: docker:git
     steps:
       - checkout
-      - setup_remote_docker
       - docker_prepare
       - run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
-      - docker_build:
-          tag: -t matrixdotorg/synapse:latest
-          platforms: linux/amd64
+      # for `latest`, we don't want the arm images to disappear, so don't update the tag
+      # until all of the platforms are built.
       - docker_build:
           tag: -t matrixdotorg/synapse:latest
           platforms: linux/amd64,linux/arm/v7,linux/arm64
@@ -46,12 +45,16 @@ workflows:
 
 commands:
   docker_prepare:
-    description: Downloads the buildx cli plugin and enables multiarch images
+    description: Sets up a remote docker server, downloads the buildx cli plugin, and enables multiarch images
     parameters:
       buildx_version:
         type: string
         default: "v0.4.1"
     steps:
+      - setup_remote_docker:
+          # 19.03.13 was the most recent available on circleci at the time of
+          # writing.
+          version: 19.03.13
       - run: apk add --no-cache curl
       - run: mkdir -vp ~/.docker/cli-plugins/ ~/dockercache
       - run: curl --silent -L "https://github.com/docker/buildx/releases/download/<< parameters.buildx_version >>/buildx-<< parameters.buildx_version >>.linux-amd64" > ~/.docker/cli-plugins/docker-buildx
diff --git a/.gitignore b/.gitignore
index 9bb5bdd647..2cef1b0a5a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,10 +12,12 @@
 _trial_temp/
 _trial_temp*/
 /out
+.DS_Store
 
 # stuff that is likely to exist when you run a server locally
 /*.db
 /*.log
+/*.log.*
 /*.log.config
 /*.pid
 /.python-version
diff --git a/CHANGES.md b/CHANGES.md
index 677afeebc3..d9afcaa52b 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,359 @@
+Synapse 1.27.0 (2021-02-16)
+===========================
+
+Note that this release includes a change in Synapse to use Redis as a cache ─ as well as a pub/sub mechanism ─ if Redis support is enabled for workers. No action is needed by server administrators, and we do not expect resource usage of the Redis instance to change dramatically.
+
+This release also changes the callback URI for OpenID Connect (OIDC) identity providers. If your server is configured to use single sign-on via an OIDC/OAuth2 IdP, you may need to make configuration changes. Please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
+
+This release also changes escaping of variables in the HTML templates for SSO or email notifications. If you have customised these templates, please review [UPGRADE.rst](UPGRADE.rst) for more details on these changes.
+
+
+Bugfixes
+--------
+
+- Fix building Docker images for armv7. ([\#9405](https://github.com/matrix-org/synapse/issues/9405))
+
+
+Synapse 1.27.0rc2 (2021-02-11)
+==============================
+
+Features
+--------
+
+- Further improvements to the user experience of registration via single sign-on. ([\#9297](https://github.com/matrix-org/synapse/issues/9297))
+
+
+Bugfixes
+--------
+
+- Fix ratelimiting introduced in v1.27.0rc1 for invites to respect the `ratelimit` flag on application services. ([\#9302](https://github.com/matrix-org/synapse/issues/9302))
+- Do not automatically calculate `public_baseurl` since it can be wrong in some situations. Reverts behaviour introduced in v1.26.0. ([\#9313](https://github.com/matrix-org/synapse/issues/9313))
+
+
+Improved Documentation
+----------------------
+
+- Clarify the sample configuration for changes made to the template loading code. ([\#9310](https://github.com/matrix-org/synapse/issues/9310))
+
+
+Synapse 1.27.0rc1 (2021-02-02)
+==============================
+
+Features
+--------
+
+- Add an admin API for getting and deleting forward extremities for a room. ([\#9062](https://github.com/matrix-org/synapse/issues/9062))
+- Add an admin API for retrieving the current room state of a room. ([\#9168](https://github.com/matrix-org/synapse/issues/9168))
+- Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). ([\#9183](https://github.com/matrix-org/synapse/issues/9183), [\#9242](https://github.com/matrix-org/synapse/issues/9242))
+- Add an admin API endpoint for shadow-banning users. ([\#9209](https://github.com/matrix-org/synapse/issues/9209))
+- Add ratelimits to the 3PID `/requestToken` APIs. ([\#9238](https://github.com/matrix-org/synapse/issues/9238))
+- Add support to the OpenID Connect integration for adding the user's email address. ([\#9245](https://github.com/matrix-org/synapse/issues/9245))
+- Add ratelimits to invites in rooms and to specific users. ([\#9258](https://github.com/matrix-org/synapse/issues/9258))
+- Improve the user experience of setting up an account via single-sign on. ([\#9262](https://github.com/matrix-org/synapse/issues/9262), [\#9272](https://github.com/matrix-org/synapse/issues/9272), [\#9275](https://github.com/matrix-org/synapse/issues/9275), [\#9276](https://github.com/matrix-org/synapse/issues/9276), [\#9277](https://github.com/matrix-org/synapse/issues/9277), [\#9286](https://github.com/matrix-org/synapse/issues/9286), [\#9287](https://github.com/matrix-org/synapse/issues/9287))
+- Add phone home stats for encrypted messages. ([\#9283](https://github.com/matrix-org/synapse/issues/9283))
+- Update the redirect URI for OIDC authentication. ([\#9288](https://github.com/matrix-org/synapse/issues/9288))
+
+
+Bugfixes
+--------
+
+- Fix spurious errors in logs when deleting a non-existant pusher. ([\#9121](https://github.com/matrix-org/synapse/issues/9121))
+- Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled). ([\#9163](https://github.com/matrix-org/synapse/issues/9163))
+- Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding. ([\#9164](https://github.com/matrix-org/synapse/issues/9164))
+- Fix a long-standing bug where invalid data could cause errors when calculating the presentable room name for push. ([\#9165](https://github.com/matrix-org/synapse/issues/9165))
+- Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data. ([\#9218](https://github.com/matrix-org/synapse/issues/9218))
+- Fix a bug where `None` was passed to Synapse modules instead of an empty dictionary if an empty module `config` block was provided in the homeserver config. ([\#9229](https://github.com/matrix-org/synapse/issues/9229))
+- Fix a bug in the `make_room_admin` admin API where it failed if the admin with the greatest power level was not in the room. Contributed by Pankaj Yadav. ([\#9235](https://github.com/matrix-org/synapse/issues/9235))
+- Prevent password hashes from getting dropped if a client failed threepid validation during a User Interactive Auth stage. Removes a workaround for an ancient bug in Riot Web <v0.7.4. ([\#9265](https://github.com/matrix-org/synapse/issues/9265))
+- Fix single-sign-on when the endpoints are routed to synapse workers. ([\#9271](https://github.com/matrix-org/synapse/issues/9271))
+
+
+Improved Documentation
+----------------------
+
+- Add docs for using Gitea as OpenID provider. ([\#9134](https://github.com/matrix-org/synapse/issues/9134))
+- Add link to Matrix VoIP tester for turn-howto. ([\#9135](https://github.com/matrix-org/synapse/issues/9135))
+- Add notes on integrating with Facebook for SSO login. ([\#9244](https://github.com/matrix-org/synapse/issues/9244))
+
+
+Deprecations and Removals
+-------------------------
+
+- The `service_url` parameter in `cas_config` is deprecated in favor of `public_baseurl`. ([\#9199](https://github.com/matrix-org/synapse/issues/9199))
+- Add new endpoint `/_synapse/client/saml2` for SAML2 authentication callbacks, and deprecate the old endpoint `/_matrix/saml2`. ([\#9289](https://github.com/matrix-org/synapse/issues/9289))
+
+
+Internal Changes
+----------------
+
+- Add tests to `test_user.UsersListTestCase` for List Users Admin API. ([\#9045](https://github.com/matrix-org/synapse/issues/9045))
+- Various improvements to the federation client. ([\#9129](https://github.com/matrix-org/synapse/issues/9129))
+- Speed up chain cover calculation when persisting a batch of state events at once. ([\#9176](https://github.com/matrix-org/synapse/issues/9176))
+- Add a `long_description_type` to the package metadata. ([\#9180](https://github.com/matrix-org/synapse/issues/9180))
+- Speed up batch insertion when using PostgreSQL. ([\#9181](https://github.com/matrix-org/synapse/issues/9181), [\#9188](https://github.com/matrix-org/synapse/issues/9188))
+- Emit an error at startup if different Identity Providers are configured with the same `idp_id`. ([\#9184](https://github.com/matrix-org/synapse/issues/9184))
+- Improve performance of concurrent use of `StreamIDGenerators`. ([\#9190](https://github.com/matrix-org/synapse/issues/9190))
+- Add some missing source directories to the automatic linting script. ([\#9191](https://github.com/matrix-org/synapse/issues/9191))
+- Precompute joined hosts and store in Redis. ([\#9198](https://github.com/matrix-org/synapse/issues/9198), [\#9227](https://github.com/matrix-org/synapse/issues/9227))
+- Clean-up template loading code. ([\#9200](https://github.com/matrix-org/synapse/issues/9200))
+- Fix the Python 3.5 old dependencies build. ([\#9217](https://github.com/matrix-org/synapse/issues/9217))
+- Update `isort` to v5.7.0 to bypass a bug where it would disagree with `black` about formatting. ([\#9222](https://github.com/matrix-org/synapse/issues/9222))
+- Add type hints to handlers code. ([\#9223](https://github.com/matrix-org/synapse/issues/9223), [\#9232](https://github.com/matrix-org/synapse/issues/9232))
+- Fix Debian package building on Ubuntu 16.04 LTS (Xenial). ([\#9254](https://github.com/matrix-org/synapse/issues/9254))
+- Minor performance improvement during TLS handshake. ([\#9255](https://github.com/matrix-org/synapse/issues/9255))
+- Refactor the generation of summary text for email notifications. ([\#9260](https://github.com/matrix-org/synapse/issues/9260))
+- Restore PyPy compatibility by not calling CPython-specific GC methods when under PyPy. ([\#9270](https://github.com/matrix-org/synapse/issues/9270))
+
+
+Synapse 1.26.0 (2021-01-27)
+===========================
+
+This release brings a new schema version for Synapse and rolling back to a previous
+version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
+on these changes and for general upgrade guidance.
+
+No significant changes since 1.26.0rc2.
+
+
+Synapse 1.26.0rc2 (2021-01-25)
+==============================
+
+Bugfixes
+--------
+
+- Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195))
+- Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210))
+
+
+Internal Changes
+----------------
+
+- Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189))
+- Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204))
+
+
+Synapse 1.26.0rc1 (2021-01-20)
+==============================
+
+This release brings a new schema version for Synapse and rolling back to a previous
+version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
+on these changes and for general upgrade guidance.
+
+Features
+--------
+
+- Add support for multiple SSO Identity Providers. ([\#9015](https://github.com/matrix-org/synapse/issues/9015), [\#9017](https://github.com/matrix-org/synapse/issues/9017), [\#9036](https://github.com/matrix-org/synapse/issues/9036), [\#9067](https://github.com/matrix-org/synapse/issues/9067), [\#9081](https://github.com/matrix-org/synapse/issues/9081), [\#9082](https://github.com/matrix-org/synapse/issues/9082), [\#9105](https://github.com/matrix-org/synapse/issues/9105), [\#9107](https://github.com/matrix-org/synapse/issues/9107), [\#9109](https://github.com/matrix-org/synapse/issues/9109), [\#9110](https://github.com/matrix-org/synapse/issues/9110), [\#9127](https://github.com/matrix-org/synapse/issues/9127), [\#9153](https://github.com/matrix-org/synapse/issues/9153), [\#9154](https://github.com/matrix-org/synapse/issues/9154), [\#9177](https://github.com/matrix-org/synapse/issues/9177))
+- During user-interactive authentication via single-sign-on, give a better error if the user uses the wrong account on the SSO IdP. ([\#9091](https://github.com/matrix-org/synapse/issues/9091))
+- Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. ([\#9159](https://github.com/matrix-org/synapse/issues/9159))
+- Improve performance when calculating ignored users in large rooms. ([\#9024](https://github.com/matrix-org/synapse/issues/9024))
+- Implement [MSC2176](https://github.com/matrix-org/matrix-doc/pull/2176) in an experimental room version. ([\#8984](https://github.com/matrix-org/synapse/issues/8984))
+- Add an admin API for protecting local media from quarantine. ([\#9086](https://github.com/matrix-org/synapse/issues/9086))
+- Remove a user's avatar URL and display name when deactivated with the Admin API. ([\#8932](https://github.com/matrix-org/synapse/issues/8932))
+- Update `/_synapse/admin/v1/users/<user_id>/joined_rooms` to work for both local and remote users. ([\#8948](https://github.com/matrix-org/synapse/issues/8948))
+- Add experimental support for handling to-device messages on worker processes. ([\#9042](https://github.com/matrix-org/synapse/issues/9042), [\#9043](https://github.com/matrix-org/synapse/issues/9043), [\#9044](https://github.com/matrix-org/synapse/issues/9044), [\#9130](https://github.com/matrix-org/synapse/issues/9130))
+- Add experimental support for handling `/keys/claim` and `/room_keys` APIs on worker processes. ([\#9068](https://github.com/matrix-org/synapse/issues/9068))
+- Add experimental support for handling `/devices` API on worker processes. ([\#9092](https://github.com/matrix-org/synapse/issues/9092))
+- Add experimental support for moving off receipts and account data persistence off master. ([\#9104](https://github.com/matrix-org/synapse/issues/9104), [\#9166](https://github.com/matrix-org/synapse/issues/9166))
+
+
+Bugfixes
+--------
+
+- Fix a long-standing issue where an internal server error would occur when requesting a profile over federation that did not include a display name / avatar URL. ([\#9023](https://github.com/matrix-org/synapse/issues/9023))
+- Fix a long-standing bug where some caches could grow larger than configured. ([\#9028](https://github.com/matrix-org/synapse/issues/9028))
+- Fix error handling during insertion of client IPs into the database. ([\#9051](https://github.com/matrix-org/synapse/issues/9051))
+- Fix bug where we didn't correctly record CPU time spent in `on_new_event` block. ([\#9053](https://github.com/matrix-org/synapse/issues/9053))
+- Fix a minor bug which could cause confusing error messages from invalid configurations. ([\#9054](https://github.com/matrix-org/synapse/issues/9054))
+- Fix incorrect exit code when there is an error at startup. ([\#9059](https://github.com/matrix-org/synapse/issues/9059))
+- Fix `JSONDecodeError` spamming the logs when sending transactions to remote servers. ([\#9070](https://github.com/matrix-org/synapse/issues/9070))
+- Fix "Failed to send request" errors when a client provides an invalid room alias. ([\#9071](https://github.com/matrix-org/synapse/issues/9071))
+- Fix bugs in federation catchup logic that caused outbound federation to be delayed for large servers after start up. Introduced in v1.8.0 and v1.21.0. ([\#9114](https://github.com/matrix-org/synapse/issues/9114), [\#9116](https://github.com/matrix-org/synapse/issues/9116))
+- Fix corruption of `pushers` data when a postgres bouncer is used. ([\#9117](https://github.com/matrix-org/synapse/issues/9117))
+- Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. ([\#9128](https://github.com/matrix-org/synapse/issues/9128))
+- Fix "Unhandled error in Deferred: BodyExceededMaxSize" errors when .well-known files that are too large. ([\#9108](https://github.com/matrix-org/synapse/issues/9108))
+- Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0. ([\#9145](https://github.com/matrix-org/synapse/issues/9145))
+- Fix a long-standing bug "ValueError: invalid literal for int() with base 10" when `/publicRooms` is requested with an invalid `server` parameter. ([\#9161](https://github.com/matrix-org/synapse/issues/9161))
+
+
+Improved Documentation
+----------------------
+
+- Add some extra docs for getting Synapse running on macOS. ([\#8997](https://github.com/matrix-org/synapse/issues/8997))
+- Correct a typo in the `systemd-with-workers` documentation. ([\#9035](https://github.com/matrix-org/synapse/issues/9035))
+- Correct a typo in `INSTALL.md`. ([\#9040](https://github.com/matrix-org/synapse/issues/9040))
+- Add missing `user_mapping_provider` configuration to the Keycloak OIDC example. Contributed by @chris-ruecker. ([\#9057](https://github.com/matrix-org/synapse/issues/9057))
+- Quote `pip install` packages when extras are used to avoid shells interpreting bracket characters. ([\#9151](https://github.com/matrix-org/synapse/issues/9151))
+
+
+Deprecations and Removals
+-------------------------
+
+- Remove broken and unmaintained `demo/webserver.py` script. ([\#9039](https://github.com/matrix-org/synapse/issues/9039))
+
+
+Internal Changes
+----------------
+
+- Improve efficiency of large state resolutions. ([\#8868](https://github.com/matrix-org/synapse/issues/8868), [\#9029](https://github.com/matrix-org/synapse/issues/9029), [\#9115](https://github.com/matrix-org/synapse/issues/9115), [\#9118](https://github.com/matrix-org/synapse/issues/9118), [\#9124](https://github.com/matrix-org/synapse/issues/9124))
+- Various clean-ups to the structured logging and logging context code. ([\#8939](https://github.com/matrix-org/synapse/issues/8939))
+- Ensure rejected events get added to some metadata tables. ([\#9016](https://github.com/matrix-org/synapse/issues/9016))
+- Ignore date-rotated homeserver logs saved to disk. ([\#9018](https://github.com/matrix-org/synapse/issues/9018))
+- Remove an unused column from `access_tokens` table. ([\#9025](https://github.com/matrix-org/synapse/issues/9025))
+- Add a `-noextras` factor to `tox.ini`, to support running the tests with no optional dependencies. ([\#9030](https://github.com/matrix-org/synapse/issues/9030))
+- Fix running unit tests when optional dependencies are not installed. ([\#9031](https://github.com/matrix-org/synapse/issues/9031))
+- Allow bumping schema version when using split out state database. ([\#9033](https://github.com/matrix-org/synapse/issues/9033))
+- Configure the linters to run on a consistent set of files. ([\#9038](https://github.com/matrix-org/synapse/issues/9038))
+- Various cleanups to device inbox store. ([\#9041](https://github.com/matrix-org/synapse/issues/9041))
+- Drop unused database tables. ([\#9055](https://github.com/matrix-org/synapse/issues/9055))
+- Remove unused `SynapseService` class. ([\#9058](https://github.com/matrix-org/synapse/issues/9058))
+- Remove unnecessary declarations in the tests for the admin API. ([\#9063](https://github.com/matrix-org/synapse/issues/9063))
+- Remove `SynapseRequest.get_user_agent`. ([\#9069](https://github.com/matrix-org/synapse/issues/9069))
+- Remove redundant `Homeserver.get_ip_from_request` method. ([\#9080](https://github.com/matrix-org/synapse/issues/9080))
+- Add type hints to media repository. ([\#9093](https://github.com/matrix-org/synapse/issues/9093))
+- Fix the wrong arguments being passed to `BlacklistingAgentWrapper` from `MatrixFederationAgent`. Contributed by Timothy Leung. ([\#9098](https://github.com/matrix-org/synapse/issues/9098))
+- Reduce the scope of caught exceptions in `BlacklistingAgentWrapper`. ([\#9106](https://github.com/matrix-org/synapse/issues/9106))
+- Improve `UsernamePickerTestCase`. ([\#9112](https://github.com/matrix-org/synapse/issues/9112))
+- Remove dependency on `distutils`. ([\#9125](https://github.com/matrix-org/synapse/issues/9125))
+- Enforce that replication HTTP clients are called with keyword arguments only. ([\#9144](https://github.com/matrix-org/synapse/issues/9144))
+- Fix the Python 3.5 / old dependencies build in CI. ([\#9146](https://github.com/matrix-org/synapse/issues/9146))
+- Replace the old `perspectives` option in the Synapse docker config file template with `trusted_key_servers`. ([\#9157](https://github.com/matrix-org/synapse/issues/9157))
+
+
+Synapse 1.25.0 (2021-01-13)
+===========================
+
+Ending Support for Python 3.5 and Postgres 9.5
+----------------------------------------------
+
+With this release, the Synapse team is announcing a formal deprecation policy for our platform dependencies, like Python and PostgreSQL:
+
+All future releases of Synapse will follow the upstream end-of-life schedules.
+
+Which means:
+
+* This is the last release which guarantees support for Python 3.5.
+* We will end support for PostgreSQL 9.5 early next month.
+* We will end support for Python 3.6 and PostgreSQL 9.6 near the end of the year.
+
+Crucially, this means __we will not produce .deb packages for Debian 9 (Stretch) or Ubuntu 16.04 (Xenial)__ beyond the transition period described below.
+
+The website https://endoflife.date/ has convenient summaries of the support schedules for projects like [Python](https://endoflife.date/python) and [PostgreSQL](https://endoflife.date/postgresql).
+
+If you are unable to upgrade your environment to a supported version of Python or Postgres, we encourage you to consider using the [Synapse Docker images](./INSTALL.md#docker-images-and-ansible-playbooks) instead.
+
+### Transition Period
+
+We will make a good faith attempt to avoid breaking compatibility in all releases through the end of March 2021. However, critical security vulnerabilities in dependencies or other unanticipated circumstances may arise which necessitate breaking compatibility earlier.
+
+We intend to continue producing .deb packages for Debian 9 (Stretch) and Ubuntu 16.04 (Xenial) through the transition period.
+
+Removal warning
+---------------
+
+The old [Purge Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/purge_room.md)
+and [Shutdown Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/shutdown_room.md)
+are deprecated and will be removed in a future release. They will be replaced by the
+[Delete Room API](https://github.com/matrix-org/synapse/tree/master/docs/admin_api/rooms.md#delete-room-api).
+
+`POST /_synapse/admin/v1/rooms/<room_id>/delete` replaces `POST /_synapse/admin/v1/purge_room` and
+`POST /_synapse/admin/v1/shutdown_room/<room_id>`.
+
+Bugfixes
+--------
+
+- Fix HTTP proxy support when using a proxy that is on a blacklisted IP. Introduced in v1.25.0rc1. Contributed by @Bubu. ([\#9084](https://github.com/matrix-org/synapse/issues/9084))
+
+
+Synapse 1.25.0rc1 (2021-01-06)
+==============================
+
+Features
+--------
+
+- Add an admin API that lets server admins get power in rooms in which local users have power. ([\#8756](https://github.com/matrix-org/synapse/issues/8756))
+- Add optional HTTP authentication to replication endpoints. ([\#8853](https://github.com/matrix-org/synapse/issues/8853))
+- Improve the error messages printed as a result of configuration problems for extension modules. ([\#8874](https://github.com/matrix-org/synapse/issues/8874))
+- Add the number of local devices to Room Details Admin API. Contributed by @dklimpel. ([\#8886](https://github.com/matrix-org/synapse/issues/8886))
+- Add `X-Robots-Tag` header to stop web crawlers from indexing media. Contributed by Aaron Raimist. ([\#8887](https://github.com/matrix-org/synapse/issues/8887))
+- Spam-checkers may now define their methods as `async`. ([\#8890](https://github.com/matrix-org/synapse/issues/8890))
+- Add support for allowing users to pick their own user ID during a single-sign-on login. ([\#8897](https://github.com/matrix-org/synapse/issues/8897), [\#8900](https://github.com/matrix-org/synapse/issues/8900), [\#8911](https://github.com/matrix-org/synapse/issues/8911), [\#8938](https://github.com/matrix-org/synapse/issues/8938), [\#8941](https://github.com/matrix-org/synapse/issues/8941), [\#8942](https://github.com/matrix-org/synapse/issues/8942), [\#8951](https://github.com/matrix-org/synapse/issues/8951))
+- Add an `email.invite_client_location` configuration option to send a web client location to the invite endpoint on the identity server which allows customisation of the email template. ([\#8930](https://github.com/matrix-org/synapse/issues/8930))
+- The search term in the list room and list user Admin APIs is now treated as case-insensitive. ([\#8931](https://github.com/matrix-org/synapse/issues/8931))
+- Apply an IP range blacklist to push and key revocation requests. ([\#8821](https://github.com/matrix-org/synapse/issues/8821), [\#8870](https://github.com/matrix-org/synapse/issues/8870), [\#8954](https://github.com/matrix-org/synapse/issues/8954))
+- Add an option to allow re-use of user-interactive authentication sessions for a period of time. ([\#8970](https://github.com/matrix-org/synapse/issues/8970))
+- Allow running the redact endpoint on workers. ([\#8994](https://github.com/matrix-org/synapse/issues/8994))
+
+
+Bugfixes
+--------
+
+- Fix bug where we might not correctly calculate the current state for rooms with multiple extremities. ([\#8827](https://github.com/matrix-org/synapse/issues/8827))
+- Fix a long-standing bug in the register admin endpoint (`/_synapse/admin/v1/register`) when the `mac` field was not provided. The endpoint now properly returns a 400 error. Contributed by @edwargix. ([\#8837](https://github.com/matrix-org/synapse/issues/8837))
+- Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password. ([\#8858](https://github.com/matrix-org/synapse/issues/8858))
+- Fix a longstanding bug where a 500 error would be returned if the `Content-Length` header was not provided to the upload media resource. ([\#8862](https://github.com/matrix-org/synapse/issues/8862))
+- Add additional validation to pusher URLs to be compliant with the specification. ([\#8865](https://github.com/matrix-org/synapse/issues/8865))
+- Fix the error code that is returned when a user tries to register on a homeserver on which new-user registration has been disabled. ([\#8867](https://github.com/matrix-org/synapse/issues/8867))
+- Fix a bug where `PUT /_synapse/admin/v2/users/<user_id>` failed to create a new user when `avatar_url` is specified. Bug introduced in Synapse v1.9.0. ([\#8872](https://github.com/matrix-org/synapse/issues/8872))
+- Fix a 500 error when attempting to preview an empty HTML file. ([\#8883](https://github.com/matrix-org/synapse/issues/8883))
+- Fix occasional deadlock when handling SIGHUP. ([\#8918](https://github.com/matrix-org/synapse/issues/8918))
+- Fix login API to not ratelimit application services that have ratelimiting disabled. ([\#8920](https://github.com/matrix-org/synapse/issues/8920))
+- Fix bug where we ratelimited auto joining of rooms on registration (using `auto_join_rooms` config). ([\#8921](https://github.com/matrix-org/synapse/issues/8921))
+- Fix a bug where deactivated users appeared in the user directory when their profile information was updated. ([\#8933](https://github.com/matrix-org/synapse/issues/8933), [\#8964](https://github.com/matrix-org/synapse/issues/8964))
+- Fix bug introduced in Synapse v1.24.0 which would cause an exception on startup if both `enabled` and `localdb_enabled` were set to `False` in the `password_config` setting of the configuration file. ([\#8937](https://github.com/matrix-org/synapse/issues/8937))
+- Fix a bug where 500 errors would be returned if the `m.room_history_visibility` event had invalid content. ([\#8945](https://github.com/matrix-org/synapse/issues/8945))
+- Fix a bug causing common English words to not be considered for a user directory search. ([\#8959](https://github.com/matrix-org/synapse/issues/8959))
+- Fix bug where application services couldn't register new ghost users if the server had reached its MAU limit. ([\#8962](https://github.com/matrix-org/synapse/issues/8962))
+- Fix a long-standing bug where a `m.image` event without a `url` would cause errors on push. ([\#8965](https://github.com/matrix-org/synapse/issues/8965))
+- Fix a small bug in v2 state resolution algorithm, which could also cause performance issues for rooms with large numbers of power levels. ([\#8971](https://github.com/matrix-org/synapse/issues/8971))
+- Add validation to the `sendToDevice` API to raise a missing parameters error instead of a 500 error. ([\#8975](https://github.com/matrix-org/synapse/issues/8975))
+- Add validation of group IDs to raise a 400 error instead of a 500 eror. ([\#8977](https://github.com/matrix-org/synapse/issues/8977))
+
+
+Improved Documentation
+----------------------
+
+- Fix the "Event persist rate" section of the included grafana dashboard by adding missing prometheus rules. ([\#8802](https://github.com/matrix-org/synapse/issues/8802))
+- Combine related media admin API docs. ([\#8839](https://github.com/matrix-org/synapse/issues/8839))
+- Fix an error in the documentation for the SAML username mapping provider. ([\#8873](https://github.com/matrix-org/synapse/issues/8873))
+- Clarify comments around template directories in `sample_config.yaml`. ([\#8891](https://github.com/matrix-org/synapse/issues/8891))
+- Move instructions for database setup, adjusted heading levels and improved syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by @fossterer. ([\#8987](https://github.com/matrix-org/synapse/issues/8987))
+- Update the example value of `group_creation_prefix` in the sample configuration. ([\#8992](https://github.com/matrix-org/synapse/issues/8992))
+- Link the Synapse developer room to the development section in the docs. ([\#9002](https://github.com/matrix-org/synapse/issues/9002))
+
+
+Deprecations and Removals
+-------------------------
+
+- Deprecate Shutdown Room and Purge Room Admin APIs. ([\#8829](https://github.com/matrix-org/synapse/issues/8829))
+
+
+Internal Changes
+----------------
+
+- Properly store the mapping of external ID to Matrix ID for CAS users. ([\#8856](https://github.com/matrix-org/synapse/issues/8856), [\#8958](https://github.com/matrix-org/synapse/issues/8958))
+- Remove some unnecessary stubbing from unit tests. ([\#8861](https://github.com/matrix-org/synapse/issues/8861))
+- Remove unused `FakeResponse` class from unit tests. ([\#8864](https://github.com/matrix-org/synapse/issues/8864))
+- Pass `room_id` to `get_auth_chain_difference`. ([\#8879](https://github.com/matrix-org/synapse/issues/8879))
+- Add type hints to push module. ([\#8880](https://github.com/matrix-org/synapse/issues/8880), [\#8882](https://github.com/matrix-org/synapse/issues/8882), [\#8901](https://github.com/matrix-org/synapse/issues/8901), [\#8940](https://github.com/matrix-org/synapse/issues/8940), [\#8943](https://github.com/matrix-org/synapse/issues/8943), [\#9020](https://github.com/matrix-org/synapse/issues/9020))
+- Simplify logic for handling user-interactive-auth via single-sign-on servers. ([\#8881](https://github.com/matrix-org/synapse/issues/8881))
+- Skip the SAML tests if the requirements (`pysaml2` and `xmlsec1`) aren't available. ([\#8905](https://github.com/matrix-org/synapse/issues/8905))
+- Fix multiarch docker image builds. ([\#8906](https://github.com/matrix-org/synapse/issues/8906))
+- Don't publish `latest` docker image until all archs are built. ([\#8909](https://github.com/matrix-org/synapse/issues/8909))
+- Various clean-ups to the structured logging and logging context code. ([\#8916](https://github.com/matrix-org/synapse/issues/8916), [\#8935](https://github.com/matrix-org/synapse/issues/8935))
+- Automatically drop stale forward-extremities under some specific conditions. ([\#8929](https://github.com/matrix-org/synapse/issues/8929))
+- Refactor test utilities for injecting HTTP requests. ([\#8946](https://github.com/matrix-org/synapse/issues/8946))
+- Add a maximum size of 50 kilobytes to .well-known lookups. ([\#8950](https://github.com/matrix-org/synapse/issues/8950))
+- Fix bug in `generate_log_config` script which made it write empty files. ([\#8952](https://github.com/matrix-org/synapse/issues/8952))
+- Clean up tox.ini file; disable coverage checking for non-test runs. ([\#8963](https://github.com/matrix-org/synapse/issues/8963))
+- Add type hints to the admin and room list handlers. ([\#8973](https://github.com/matrix-org/synapse/issues/8973))
+- Add type hints to the receipts and user directory handlers. ([\#8976](https://github.com/matrix-org/synapse/issues/8976))
+- Drop the unused `local_invites` table. ([\#8979](https://github.com/matrix-org/synapse/issues/8979))
+- Add type hints to the base storage code. ([\#8980](https://github.com/matrix-org/synapse/issues/8980))
+- Support using PyJWT v2.0.0 in the test suite. ([\#8986](https://github.com/matrix-org/synapse/issues/8986))
+- Fix `tests.federation.transport.RoomDirectoryFederationTests` and ensure it runs in CI. ([\#8998](https://github.com/matrix-org/synapse/issues/8998))
+- Add type hints to the crypto module. ([\#8999](https://github.com/matrix-org/synapse/issues/8999))
+
+
 Synapse 1.24.0 (2020-12-09)
 ===========================
 
@@ -44,6 +400,58 @@ Internal Changes
 - Add a maximum version for pysaml2 on Python 3.5. ([\#8898](https://github.com/matrix-org/synapse/issues/8898))
 
 
+Synapse 1.23.1 (2020-12-09)
+===========================
+
+Due to the two security issues highlighted below, server administrators are
+encouraged to update Synapse. We are not aware of these vulnerabilities being
+exploited in the wild.
+
+Security advisory
+-----------------
+
+The following issues are fixed in v1.23.1 and v1.24.0.
+
+- There is a denial of service attack
+  ([CVE-2020-26257](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26257))
+  against the federation APIs in which future events will not be correctly sent
+  to other servers over federation. This affects all servers that participate in
+  open federation. (Fixed in [#8776](https://github.com/matrix-org/synapse/pull/8776)).
+
+- Synapse may be affected by OpenSSL
+  [CVE-2020-1971](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-1971).
+  Synapse administrators should ensure that they have the latest versions of
+  the cryptography Python package installed.
+
+To upgrade Synapse along with the cryptography package:
+
+* Administrators using the [`matrix.org` Docker
+  image](https://hub.docker.com/r/matrixdotorg/synapse/) or the [Debian/Ubuntu
+  packages from
+  `matrix.org`](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#matrixorg-packages)
+  should ensure that they have version 1.24.0 or 1.23.1 installed: these images include
+  the updated packages.
+* Administrators who have [installed Synapse from
+  source](https://github.com/matrix-org/synapse/blob/master/INSTALL.md#installing-from-source)
+  should upgrade the cryptography package within their virtualenv by running:
+  ```sh
+  <path_to_virtualenv>/bin/pip install 'cryptography>=3.3'
+  ```
+* Administrators who have installed Synapse from distribution packages should
+  consult the information from their distributions.
+
+Bugfixes
+--------
+
+- Fix a bug in some federation APIs which could lead to unexpected behaviour if different parameters were set in the URI and the request body. ([\#8776](https://github.com/matrix-org/synapse/issues/8776))
+
+
+Internal Changes
+----------------
+
+- Add a maximum version for pysaml2 on Python 3.5. ([\#8898](https://github.com/matrix-org/synapse/issues/8898))
+
+
 Synapse 1.24.0rc2 (2020-12-04)
 ==============================
 
diff --git a/INSTALL.md b/INSTALL.md
index eaeb690092..d405d9fe55 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -1,19 +1,44 @@
-- [Choosing your server name](#choosing-your-server-name)
-- [Picking a database engine](#picking-a-database-engine)
-- [Installing Synapse](#installing-synapse)
-  - [Installing from source](#installing-from-source)
-    - [Platform-Specific Instructions](#platform-specific-instructions)
-  - [Prebuilt packages](#prebuilt-packages)
-- [Setting up Synapse](#setting-up-synapse)
-  - [TLS certificates](#tls-certificates)
-  - [Client Well-Known URI](#client-well-known-uri)
-  - [Email](#email)
-  - [Registering a user](#registering-a-user)
-  - [Setting up a TURN server](#setting-up-a-turn-server)
-  - [URL previews](#url-previews)
-- [Troubleshooting Installation](#troubleshooting-installation)
-
-# Choosing your server name
+# Installation Instructions
+
+There are 3 steps to follow under **Installation Instructions**.
+
+- [Installation Instructions](#installation-instructions)
+  - [Choosing your server name](#choosing-your-server-name)
+  - [Installing Synapse](#installing-synapse)
+    - [Installing from source](#installing-from-source)
+      - [Platform-Specific Instructions](#platform-specific-instructions)
+        - [Debian/Ubuntu/Raspbian](#debianubunturaspbian)
+        - [ArchLinux](#archlinux)
+        - [CentOS/Fedora](#centosfedora)
+        - [macOS](#macos)
+        - [OpenSUSE](#opensuse)
+        - [OpenBSD](#openbsd)
+        - [Windows](#windows)
+    - [Prebuilt packages](#prebuilt-packages)
+      - [Docker images and Ansible playbooks](#docker-images-and-ansible-playbooks)
+      - [Debian/Ubuntu](#debianubuntu)
+        - [Matrix.org packages](#matrixorg-packages)
+        - [Downstream Debian packages](#downstream-debian-packages)
+        - [Downstream Ubuntu packages](#downstream-ubuntu-packages)
+      - [Fedora](#fedora)
+      - [OpenSUSE](#opensuse-1)
+      - [SUSE Linux Enterprise Server](#suse-linux-enterprise-server)
+      - [ArchLinux](#archlinux-1)
+      - [Void Linux](#void-linux)
+      - [FreeBSD](#freebsd)
+      - [OpenBSD](#openbsd-1)
+      - [NixOS](#nixos)
+  - [Setting up Synapse](#setting-up-synapse)
+    - [Using PostgreSQL](#using-postgresql)
+    - [TLS certificates](#tls-certificates)
+    - [Client Well-Known URI](#client-well-known-uri)
+    - [Email](#email)
+    - [Registering a user](#registering-a-user)
+    - [Setting up a TURN server](#setting-up-a-turn-server)
+    - [URL previews](#url-previews)
+    - [Troubleshooting Installation](#troubleshooting-installation)
+
+## Choosing your server name
 
 It is important to choose the name for your server before you install Synapse,
 because it cannot be changed later.
@@ -29,28 +54,9 @@ that your email address is probably `user@example.com` rather than
 `user@email.example.com`) - but doing so may require more advanced setup: see
 [Setting up Federation](docs/federate.md).
 
-# Picking a database engine
+## Installing Synapse
 
-Synapse offers two database engines:
- * [PostgreSQL](https://www.postgresql.org)
- * [SQLite](https://sqlite.org/)
-
-Almost all installations should opt to use PostgreSQL. Advantages include:
-
-* significant performance improvements due to the superior threading and
-  caching model, smarter query optimiser
-* allowing the DB to be run on separate hardware
-
-For information on how to install and use PostgreSQL, please see
-[docs/postgres.md](docs/postgres.md)
-
-By default Synapse uses SQLite and in doing so trades performance for convenience.
-SQLite is only recommended in Synapse for testing purposes or for servers with
-light workloads.
-
-# Installing Synapse
-
-## Installing from source
+### Installing from source
 
 (Prebuilt packages are available for some platforms - see [Prebuilt packages](#prebuilt-packages).)
 
@@ -68,7 +74,7 @@ these on various platforms.
 
 To install the Synapse homeserver run:
 
-```
+```sh
 mkdir -p ~/synapse
 virtualenv -p python3 ~/synapse/env
 source ~/synapse/env/bin/activate
@@ -85,7 +91,7 @@ prefer.
 This Synapse installation can then be later upgraded by using pip again with the
 update flag:
 
-```
+```sh
 source ~/synapse/env/bin/activate
 pip install -U matrix-synapse
 ```
@@ -93,7 +99,7 @@ pip install -U matrix-synapse
 Before you can start Synapse, you will need to generate a configuration
 file. To do this, run (in your virtualenv, as before):
 
-```
+```sh
 cd ~/synapse
 python -m synapse.app.homeserver \
     --server-name my.domain.name \
@@ -111,45 +117,43 @@ wise to back them up somewhere safe. (If, for whatever reason, you do need to
 change your homeserver's keys, you may find that other homeserver have the
 old key cached. If you update the signing key, you should change the name of the
 key in the `<server name>.signing.key` file (the second word) to something
-different. See the
-[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys)
-for more information on key management).
+different. See the [spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) for more information on key management).
 
 To actually run your new homeserver, pick a working directory for Synapse to
 run (e.g. `~/synapse`), and:
 
-```
+```sh
 cd ~/synapse
 source env/bin/activate
 synctl start
 ```
 
-### Platform-Specific Instructions
+#### Platform-Specific Instructions
 
-#### Debian/Ubuntu/Raspbian
+##### Debian/Ubuntu/Raspbian
 
 Installing prerequisites on Ubuntu or Debian:
 
-```
-sudo apt-get install build-essential python3-dev libffi-dev \
+```sh
+sudo apt install build-essential python3-dev libffi-dev \
                      python3-pip python3-setuptools sqlite3 \
                      libssl-dev virtualenv libjpeg-dev libxslt1-dev
 ```
 
-#### ArchLinux
+##### ArchLinux
 
 Installing prerequisites on ArchLinux:
 
-```
+```sh
 sudo pacman -S base-devel python python-pip \
                python-setuptools python-virtualenv sqlite3
 ```
 
-#### CentOS/Fedora
+##### CentOS/Fedora
 
 Installing prerequisites on CentOS 8 or Fedora>26:
 
-```
+```sh
 sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
                  libwebp-devel tk-devel redhat-rpm-config \
                  python3-virtualenv libffi-devel openssl-devel
@@ -158,7 +162,7 @@ sudo dnf groupinstall "Development Tools"
 
 Installing prerequisites on CentOS 7 or Fedora<=25:
 
-```
+```sh
 sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
                  lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \
                  python3-virtualenv libffi-devel openssl-devel
@@ -170,11 +174,11 @@ uses SQLite 3.7. You may be able to work around this by installing a more
 recent SQLite version, but it is recommended that you instead use a Postgres
 database: see [docs/postgres.md](docs/postgres.md).
 
-#### macOS
+##### macOS
 
 Installing prerequisites on macOS:
 
-```
+```sh
 xcode-select --install
 sudo easy_install pip
 sudo pip install virtualenv
@@ -184,22 +188,23 @@ brew install pkg-config libffi
 On macOS Catalina (10.15) you may need to explicitly install OpenSSL
 via brew and inform `pip` about it so that `psycopg2` builds:
 
-```
+```sh
 brew install openssl@1.1
-export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/
+export LDFLAGS="-L/usr/local/opt/openssl/lib"
+export CPPFLAGS="-I/usr/local/opt/openssl/include"
 ```
 
-#### OpenSUSE
+##### OpenSUSE
 
 Installing prerequisites on openSUSE:
 
-```
+```sh
 sudo zypper in -t pattern devel_basis
 sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
                python-devel libffi-devel libopenssl-devel libjpeg62-devel
 ```
 
-#### OpenBSD
+##### OpenBSD
 
 A port of Synapse is available under `net/synapse`. The filesystem
 underlying the homeserver directory (defaults to `/var/synapse`) has to be
@@ -213,73 +218,72 @@ mounted with `wxallowed` (cf. `mount(8)`).
 Creating a `WRKOBJDIR` for building python under `/usr/local` (which on a
 default OpenBSD installation is mounted with `wxallowed`):
 
-```
+```sh
 doas mkdir /usr/local/pobj_wxallowed
 ```
 
 Assuming `PORTS_PRIVSEP=Yes` (cf. `bsd.port.mk(5)`) and `SUDO=doas` are
 configured in `/etc/mk.conf`:
 
-```
+```sh
 doas chown _pbuild:_pbuild /usr/local/pobj_wxallowed
 ```
 
 Setting the `WRKOBJDIR` for building python:
 
-```
+```sh
 echo WRKOBJDIR_lang/python/3.7=/usr/local/pobj_wxallowed  \\nWRKOBJDIR_lang/python/2.7=/usr/local/pobj_wxallowed >> /etc/mk.conf
 ```
 
 Building Synapse:
 
-```
+```sh
 cd /usr/ports/net/synapse
 make install
 ```
 
-#### Windows
+##### Windows
 
 If you wish to run or develop Synapse on Windows, the Windows Subsystem For
 Linux provides a Linux environment on Windows 10 which is capable of using the
 Debian, Fedora, or source installation methods. More information about WSL can
-be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for
-Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server
+be found at <https://docs.microsoft.com/en-us/windows/wsl/install-win10> for
+Windows 10 and <https://docs.microsoft.com/en-us/windows/wsl/install-on-server>
 for Windows Server.
 
-## Prebuilt packages
+### Prebuilt packages
 
 As an alternative to installing from source, prebuilt packages are available
 for a number of platforms.
 
-### Docker images and Ansible playbooks
+#### Docker images and Ansible playbooks
 
-There is an offical synapse image available at
-https://hub.docker.com/r/matrixdotorg/synapse which can be used with
+There is an official synapse image available at
+<https://hub.docker.com/r/matrixdotorg/synapse> which can be used with
 the docker-compose file available at [contrib/docker](contrib/docker). Further
 information on this including configuration options is available in the README
 on hub.docker.com.
 
 Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
 Dockerfile to automate a synapse server in a single Docker image, at
-https://hub.docker.com/r/avhost/docker-matrix/tags/
+<https://hub.docker.com/r/avhost/docker-matrix/tags/>
 
 Slavi Pantaleev has created an Ansible playbook,
 which installs the offical Docker image of Matrix Synapse
 along with many other Matrix-related services (Postgres database, Element, coturn,
 ma1sd, SSL support, etc.).
 For more details, see
-https://github.com/spantaleev/matrix-docker-ansible-deploy
-
+<https://github.com/spantaleev/matrix-docker-ansible-deploy>
 
-### Debian/Ubuntu
+#### Debian/Ubuntu
 
-#### Matrix.org packages
+##### Matrix.org packages
 
 Matrix.org provides Debian/Ubuntu packages of the latest stable version of
-Synapse via https://packages.matrix.org/debian/. They are available for Debian
+Synapse via <https://packages.matrix.org/debian/>. They are available for Debian
 9 (Stretch), Ubuntu 16.04 (Xenial), and later. To use them:
 
-```
+```sh
 sudo apt install -y lsb-release wget apt-transport-https
 sudo wget -O /usr/share/keyrings/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg
 echo "deb [signed-by=/usr/share/keyrings/matrix-org-archive-keyring.gpg] https://packages.matrix.org/debian/ $(lsb_release -cs) main" |
@@ -299,7 +303,7 @@ The fingerprint of the repository signing key (as shown by `gpg
 /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is
 `AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`.
 
-#### Downstream Debian packages
+##### Downstream Debian packages
 
 We do not recommend using the packages from the default Debian `buster`
 repository at this time, as they are old and suffer from known security
@@ -311,49 +315,49 @@ for information on how to use backports.
 If you are using Debian `sid` or testing, Synapse is available in the default
 repositories and it should be possible to install it simply with:
 
-```
+```sh
 sudo apt install matrix-synapse
 ```
 
-#### Downstream Ubuntu packages
+##### Downstream Ubuntu packages
 
 We do not recommend using the packages in the default Ubuntu repository
 at this time, as they are old and suffer from known security vulnerabilities.
 The latest version of Synapse can be installed from [our repository](#matrixorg-packages).
 
-### Fedora
+#### Fedora
 
 Synapse is in the Fedora repositories as `matrix-synapse`:
 
-```
+```sh
 sudo dnf install matrix-synapse
 ```
 
 Oleg Girko provides Fedora RPMs at
-https://obs.infoserver.lv/project/monitor/matrix-synapse
+<https://obs.infoserver.lv/project/monitor/matrix-synapse>
 
-### OpenSUSE
+#### OpenSUSE
 
 Synapse is in the OpenSUSE repositories as `matrix-synapse`:
 
-```
+```sh
 sudo zypper install matrix-synapse
 ```
 
-### SUSE Linux Enterprise Server
+#### SUSE Linux Enterprise Server
 
 Unofficial package are built for SLES 15 in the openSUSE:Backports:SLE-15 repository at
-https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/
+<https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/>
 
-### ArchLinux
+#### ArchLinux
 
 The quickest way to get up and running with ArchLinux is probably with the community package
-https://www.archlinux.org/packages/community/any/matrix-synapse/, which should pull in most of
+<https://www.archlinux.org/packages/community/any/matrix-synapse/>, which should pull in most of
 the necessary dependencies.
 
 pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ):
 
-```
+```sh
 sudo pip install --upgrade pip
 ```
 
@@ -362,28 +366,28 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
 compile it under the right architecture. (This should not be needed if
 installing under virtualenv):
 
-```
+```sh
 sudo pip uninstall py-bcrypt
 sudo pip install py-bcrypt
 ```
 
-### Void Linux
+#### Void Linux
 
 Synapse can be found in the void repositories as 'synapse':
 
-```
+```sh
 xbps-install -Su
 xbps-install -S synapse
 ```
 
-### FreeBSD
+#### FreeBSD
 
 Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
 
- - Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
- - Packages: `pkg install py37-matrix-synapse`
+- Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
+- Packages: `pkg install py37-matrix-synapse`
 
-### OpenBSD
+#### OpenBSD
 
 As of OpenBSD 6.7 Synapse is available as a pre-compiled binary. The filesystem
 underlying the homeserver directory (defaults to `/var/synapse`) has to be
@@ -392,20 +396,35 @@ and mounting it to `/var/synapse` should be taken into consideration.
 
 Installing Synapse:
 
-```
+```sh
 doas pkg_add synapse
 ```
 
-### NixOS
+#### NixOS
 
 Robin Lambertz has packaged Synapse for NixOS at:
-https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix
+<https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix>
 
-# Setting up Synapse
+## Setting up Synapse
 
 Once you have installed synapse as above, you will need to configure it.
 
-## TLS certificates
+### Using PostgreSQL
+
+By default Synapse uses [SQLite](https://sqlite.org/) and in doing so trades performance for convenience.
+SQLite is only recommended in Synapse for testing purposes or for servers with
+very light workloads.
+
+Almost all installations should opt to use [PostgreSQL](https://www.postgresql.org). Advantages include:
+
+- significant performance improvements due to the superior threading and
+  caching model, smarter query optimiser
+- allowing the DB to be run on separate hardware
+
+For information on how to install and use PostgreSQL in Synapse, please see
+[docs/postgres.md](docs/postgres.md)
+
+### TLS certificates
 
 The default configuration exposes a single HTTP port on the local
 interface: `http://localhost:8008`. It is suitable for local testing,
@@ -419,19 +438,19 @@ The recommended way to do so is to set up a reverse proxy on port
 Alternatively, you can configure Synapse to expose an HTTPS port. To do
 so, you will need to edit `homeserver.yaml`, as follows:
 
-* First, under the `listeners` section, uncomment the configuration for the
+- First, under the `listeners` section, uncomment the configuration for the
   TLS-enabled listener. (Remove the hash sign (`#`) at the start of
   each line). The relevant lines are like this:
 
-  ```
-    - port: 8448
-      type: http
-      tls: true
-      resources:
-        - names: [client, federation]
+```yaml
+  - port: 8448
+    type: http
+    tls: true
+    resources:
+      - names: [client, federation]
   ```
 
-* You will also need to uncomment the `tls_certificate_path` and
+- You will also need to uncomment the `tls_certificate_path` and
   `tls_private_key_path` lines under the `TLS` section. You will need to manage
   provisioning of these certificates yourself — Synapse had built-in ACME
   support, but the ACMEv1 protocol Synapse implements is deprecated, not
@@ -446,7 +465,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
 For a more detailed guide to configuring your server for federation, see
 [federate.md](docs/federate.md).
 
-## Client Well-Known URI
+### Client Well-Known URI
 
 Setting up the client Well-Known URI is optional but if you set it up, it will
 allow users to enter their full username (e.g. `@user:<server_name>`) into clients
@@ -457,7 +476,7 @@ about the actual homeserver URL you are using.
 The URL `https://<server_name>/.well-known/matrix/client` should return JSON in
 the following format.
 
-```
+```json
 {
   "m.homeserver": {
     "base_url": "https://<matrix.example.com>"
@@ -467,7 +486,7 @@ the following format.
 
 It can optionally contain identity server information as well.
 
-```
+```json
 {
   "m.homeserver": {
     "base_url": "https://<matrix.example.com>"
@@ -484,7 +503,8 @@ Cross-Origin Resource Sharing (CORS) headers. A recommended value would be
 view it.
 
 In nginx this would be something like:
-```
+
+```nginx
 location /.well-known/matrix/client {
     return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}';
     default_type application/json;
@@ -497,11 +517,11 @@ correctly. `public_baseurl` should be set to the URL that clients will use to
 connect to your server. This is the same URL you put for the `m.homeserver`
 `base_url` above.
 
-```
+```yaml
 public_baseurl: "https://<matrix.example.com>"
 ```
 
-## Email
+### Email
 
 It is desirable for Synapse to have the capability to send email. This allows
 Synapse to send password reset emails, send verifications when an email address
@@ -516,7 +536,7 @@ and `notif_from` fields filled out.  You may also need to set `smtp_user`,
 If email is not configured, password reset, registration and notifications via
 email will be disabled.
 
-## Registering a user
+### Registering a user
 
 The easiest way to create a new user is to do so from a client like [Element](https://element.io/).
 
@@ -524,7 +544,7 @@ Alternatively you can do so from the command line if you have installed via pip.
 
 This can be done as follows:
 
-```
+```sh
 $ source ~/synapse/env/bin/activate
 $ synctl start # if not already running
 $ register_new_matrix_user -c homeserver.yaml http://localhost:8008
@@ -542,12 +562,12 @@ value is generated by `--generate-config`), but it should be kept secret, as
 anyone with knowledge of it can register users, including admin accounts,
 on your server even if `enable_registration` is `false`.
 
-## Setting up a TURN server
+### Setting up a TURN server
 
 For reliable VoIP calls to be routed via this homeserver, you MUST configure
 a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
 
-## URL previews
+### URL previews
 
 Synapse includes support for previewing URLs, which is disabled by default.  To
 turn it on you must enable the `url_preview_enabled: True` config parameter
@@ -557,19 +577,18 @@ This is critical from a security perspective to stop arbitrary Matrix users
 spidering 'internal' URLs on your network. At the very least we recommend that
 your loopback and RFC1918 IP addresses are blacklisted.
 
-This also requires the optional `lxml` and `netaddr` python dependencies to be
-installed. This in turn requires the `libxml2` library to be available - on
-Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for
-your OS.
+This also requires the optional `lxml` python dependency to be  installed. This
+in turn requires the `libxml2` library to be available - on  Debian/Ubuntu this
+means `apt-get install libxml2-dev`, or equivalent for your OS.
 
-# Troubleshooting Installation
+### Troubleshooting Installation
 
 `pip` seems to leak *lots* of memory during installation. For instance, a Linux
 host with 512MB of RAM may run out of memory whilst installing Twisted. If this
 happens, you will have to individually install the dependencies which are
 failing, e.g.:
 
-```
+```sh
 pip install twisted
 ```
 
diff --git a/README.rst b/README.rst
index d724cf97da..d872b11f57 100644
--- a/README.rst
+++ b/README.rst
@@ -243,6 +243,8 @@ Then update the ``users`` table in the database::
 Synapse Development
 ===================
 
+Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_
+
 Before setting up a development environment for synapse, make sure you have the
 system dependencies (such as the python header files) installed - see
 `Installing from source <INSTALL.md#installing-from-source>`_.
@@ -278,6 +280,27 @@ differ)::
 
     PASSED (skips=15, successes=1322)
 
+We recommend using the demo which starts 3 federated instances running on ports `8080` - `8082`
+
+    ./demo/start.sh
+
+(to stop, you can use `./demo/stop.sh`)
+
+If you just want to start a single instance of the app and run it directly::
+
+    # Create the homeserver.yaml config once
+    python -m synapse.app.homeserver \
+      --server-name my.domain.name \
+      --config-path homeserver.yaml \
+      --generate-config \
+      --report-stats=[yes|no]
+
+    # Start the app
+    python -m synapse.app.homeserver --config-path homeserver.yaml
+
+
+
+
 Running the Integration Tests
 =============================
 
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 6825b567e9..22edfe0d60 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -5,6 +5,16 @@ Before upgrading check if any special steps are required to upgrade from the
 version you currently have installed to the current version of Synapse. The extra
 instructions that may be required are listed later in this document.
 
+* Check that your versions of Python and PostgreSQL are still supported.
+
+  Synapse follows upstream lifecycles for `Python`_ and `PostgreSQL`_, and
+  removes support for versions which are no longer maintained.
+
+  The website https://endoflife.date also offers convenient summaries.
+
+  .. _Python: https://devguide.python.org/devcycle/#end-of-life-branches
+  .. _PostgreSQL: https://www.postgresql.org/support/versioning/
+
 * If Synapse was installed using `prebuilt packages
   <INSTALL.md#prebuilt-packages>`_, you will need to follow the normal process
   for upgrading those packages.
@@ -75,6 +85,141 @@ for example:
      wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
      dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
 
+Upgrading to v1.27.0
+====================
+
+Changes to callback URI for OAuth2 / OpenID Connect
+---------------------------------------------------
+
+This version changes the URI used for callbacks from OAuth2 identity providers. If
+your server is configured for single sign-on via an OpenID Connect or OAuth2 identity
+provider, you will need to add ``[synapse public baseurl]/_synapse/client/oidc/callback``
+to the list of permitted "redirect URIs" at the identity provider.
+
+See `docs/openid.md <docs/openid.md>`_ for more information on setting up OpenID
+Connect.
+
+(Note: a similar change is being made for SAML2; in this case the old URI
+``[synapse public baseurl]/_matrix/saml2`` is being deprecated, but will continue to
+work, so no immediate changes are required for existing installations.)
+
+Changes to HTML templates
+-------------------------
+
+The HTML templates for SSO and email notifications now have `Jinja2's autoescape <https://jinja.palletsprojects.com/en/2.11.x/api/#autoescaping>`_
+enabled for files ending in ``.html``, ``.htm``, and ``.xml``. If you have customised
+these templates and see issues when viewing them you might need to update them.
+It is expected that most configurations will need no changes.
+
+If you have customised the templates *names* for these templates, it is recommended
+to verify they end in ``.html`` to ensure autoescape is enabled.
+
+The above applies to the following templates:
+
+* ``add_threepid.html``
+* ``add_threepid_failure.html``
+* ``add_threepid_success.html``
+* ``notice_expiry.html``
+* ``notice_expiry.html``
+* ``notif_mail.html`` (which, by default, includes ``room.html`` and ``notif.html``)
+* ``password_reset.html``
+* ``password_reset_confirmation.html``
+* ``password_reset_failure.html``
+* ``password_reset_success.html``
+* ``registration.html``
+* ``registration_failure.html``
+* ``registration_success.html``
+* ``sso_account_deactivated.html``
+* ``sso_auth_bad_user.html``
+* ``sso_auth_confirm.html``
+* ``sso_auth_success.html``
+* ``sso_error.html``
+* ``sso_login_idp_picker.html``
+* ``sso_redirect_confirm.html``
+
+Upgrading to v1.26.0
+====================
+
+Rolling back to v1.25.0 after a failed upgrade
+----------------------------------------------
+
+v1.26.0 includes a lot of large changes. If something problematic occurs, you
+may want to roll-back to a previous version of Synapse. Because v1.26.0 also
+includes a new database schema version, reverting that version is also required
+alongside the generic rollback instructions mentioned above. In short, to roll
+back to v1.25.0 you need to:
+
+1. Stop the server
+2. Decrease the schema version in the database:
+
+   .. code:: sql
+
+      UPDATE schema_version SET version = 58;
+
+3. Delete the ignored users & chain cover data:
+
+   .. code:: sql
+
+      DROP TABLE IF EXISTS ignored_users;
+      UPDATE rooms SET has_auth_chain_index = false;
+
+   For PostgreSQL run:
+
+   .. code:: sql
+
+      TRUNCATE event_auth_chain_links;
+      TRUNCATE event_auth_chains;
+
+   For SQLite run:
+
+   .. code:: sql
+
+      DELETE FROM event_auth_chain_links;
+      DELETE FROM event_auth_chains;
+
+4. Mark the deltas as not run (so they will re-run on upgrade).
+
+   .. code:: sql
+
+      DELETE FROM applied_schema_deltas WHERE version = 59 AND file = "59/01ignored_user.py";
+      DELETE FROM applied_schema_deltas WHERE version = 59 AND file = "59/06chain_cover_index.sql";
+
+5. Downgrade Synapse by following the instructions for your installation method
+   in the "Rolling back to older versions" section above.
+
+Upgrading to v1.25.0
+====================
+
+Last release supporting Python 3.5
+----------------------------------
+
+This is the last release of Synapse which guarantees support with Python 3.5,
+which passed its upstream End of Life date several months ago.
+
+We will attempt to maintain support through March 2021, but without guarantees.
+
+In the future, Synapse will follow upstream schedules for ending support of
+older versions of Python and PostgreSQL. Please upgrade to at least Python 3.6
+and PostgreSQL 9.6 as soon as possible.
+
+Blacklisting IP ranges
+----------------------
+
+Synapse v1.25.0 includes new settings, ``ip_range_blacklist`` and
+``ip_range_whitelist``, for controlling outgoing requests from Synapse for federation,
+identity servers, push, and for checking key validity for third-party invite events.
+The previous setting, ``federation_ip_range_blacklist``, is deprecated. The new
+``ip_range_blacklist`` defaults to private IP ranges if it is not defined.
+
+If you have never customised ``federation_ip_range_blacklist`` it is recommended
+that you remove that setting.
+
+If you have customised ``federation_ip_range_blacklist`` you should update the
+setting name to ``ip_range_blacklist``.
+
+If you have a custom push server that is reached via private IP space you may
+need to customise ``ip_range_blacklist`` or ``ip_range_whitelist``.
+
 Upgrading to v1.24.0
 ====================
 
@@ -105,7 +250,7 @@ shown below:
 
           return {"localpart": localpart}
 
-Removal historical Synapse Admin API 
+Removal historical Synapse Admin API
 ------------------------------------
 
 Historically, the Synapse Admin API has been accessible under:
diff --git a/contrib/prometheus/synapse-v2.rules b/contrib/prometheus/synapse-v2.rules
index 6ccca2daaf..7e405bf7f0 100644
--- a/contrib/prometheus/synapse-v2.rules
+++ b/contrib/prometheus/synapse-v2.rules
@@ -58,3 +58,21 @@ groups:
     labels:
       type: "PDU"
     expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
+
+  - record: synapse_storage_events_persisted_by_source_type
+    expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_type="remote"})
+    labels:
+      type: remote
+  - record: synapse_storage_events_persisted_by_source_type
+    expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity="*client*",origin_type="local"})
+    labels:
+      type: local
+  - record: synapse_storage_events_persisted_by_source_type
+    expr: sum without(type, origin_type, origin_entity) (synapse_storage_events_persisted_events_sep{origin_entity!="*client*",origin_type="local"})
+    labels:
+      type: bridges
+  - record: synapse_storage_events_persisted_by_event_type
+    expr: sum without(origin_entity, origin_type) (synapse_storage_events_persisted_events_sep)
+  - record: synapse_storage_events_persisted_by_origin
+    expr: sum without(type) (synapse_storage_events_persisted_events_sep)
+
diff --git a/debian/build_virtualenv b/debian/build_virtualenv
index cbdde93f96..cf19084a9f 100755
--- a/debian/build_virtualenv
+++ b/debian/build_virtualenv
@@ -33,11 +33,13 @@ esac
 # Use --builtin-venv to use the better `venv` module from CPython 3.4+ rather
 # than the 2/3 compatible `virtualenv`.
 
+# Pin pip to 20.3.4 to fix breakage in 21.0 on py3.5 (xenial)
+
 dh_virtualenv \
     --install-suffix "matrix-synapse" \
     --builtin-venv \
     --python "$SNAKE" \
-    --upgrade-pip \
+    --upgrade-pip-to="20.3.4" \
     --preinstall="lxml" \
     --preinstall="mock" \
     --extra-pip-arg="--no-cache-dir" \
diff --git a/debian/changelog b/debian/changelog
index 9f47d12b7e..aa83d4e13e 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,9 +1,46 @@
+matrix-synapse-py3 (1.27.0) stable; urgency=medium
+
+  [ Dan Callahan ]
+  * Fix build on Ubuntu 16.04 LTS (Xenial).
+
+  [ Synapse Packaging team ]
+  * New synapse release 1.27.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 16 Feb 2021 13:11:28 +0000
+
+matrix-synapse-py3 (1.26.0) stable; urgency=medium
+
+  [ Richard van der Hoff ]
+  * Remove dependency on `python3-distutils`.
+
+  [ Synapse Packaging team ]
+  * New synapse release 1.26.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Wed, 27 Jan 2021 12:43:35 -0500
+
+matrix-synapse-py3 (1.25.0) stable; urgency=medium
+
+  [ Dan Callahan ]
+  * Update dependencies to account for the removal of the transitional
+    dh-systemd package from Debian Bullseye.
+
+  [ Synapse Packaging team ]
+  * New synapse release 1.25.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Wed, 13 Jan 2021 10:14:55 +0000
+
 matrix-synapse-py3 (1.24.0) stable; urgency=medium
 
   * New synapse release 1.24.0.
 
  -- Synapse Packaging team <packages@matrix.org>  Wed, 09 Dec 2020 10:14:30 +0000
 
+matrix-synapse-py3 (1.23.1) stable; urgency=medium
+
+  * New synapse release 1.23.1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Wed, 09 Dec 2020 10:40:39 +0000
+
 matrix-synapse-py3 (1.23.0) stable; urgency=medium
 
   * New synapse release 1.23.0.
diff --git a/debian/control b/debian/control
index bae14b41e4..8167a901a4 100644
--- a/debian/control
+++ b/debian/control
@@ -3,9 +3,11 @@ Section: contrib/python
 Priority: extra
 Maintainer: Synapse Packaging team <packages@matrix.org>
 # keep this list in sync with the build dependencies in docker/Dockerfile-dhvirtualenv.
+# TODO: Remove the dependency on dh-systemd after dropping support for Ubuntu xenial
+#       On all other supported releases, it's merely a transitional package which
+#       does nothing but depends on debhelper (> 9.20160709)
 Build-Depends:
- debhelper (>= 9),
- dh-systemd,
+ debhelper (>= 9.20160709) | dh-systemd,
  dh-virtualenv (>= 1.1),
  libsystemd-dev,
  libpq-dev,
@@ -29,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1)
 Depends:
  adduser,
  debconf,
- python3-distutils|libpython3-stdlib (<< 3.6),
  ${misc:Depends},
  ${shlibs:Depends},
  ${synapse:pydepends},
diff --git a/demo/webserver.py b/demo/webserver.py
deleted file mode 100644
index ba176d3bd2..0000000000
--- a/demo/webserver.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import argparse
-import BaseHTTPServer
-import os
-import SimpleHTTPServer
-import cgi, logging
-
-from daemonize import Daemonize
-
-
-class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler):
-    UPLOAD_PATH = "upload"
-
-    """
-    Accept all post request as file upload
-    """
-
-    def do_POST(self):
-
-        path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path))
-        length = self.headers["content-length"]
-        data = self.rfile.read(int(length))
-
-        with open(path, "wb") as fh:
-            fh.write(data)
-
-        self.send_response(200)
-        self.send_header("Content-Type", "application/json")
-        self.end_headers()
-
-        # Return the absolute path of the uploaded file
-        self.wfile.write('{"url":"/%s"}' % path)
-
-
-def setup():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("directory")
-    parser.add_argument("-p", "--port", dest="port", type=int, default=8080)
-    parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid")
-    args = parser.parse_args()
-
-    # Get absolute path to directory to serve, as daemonize changes to '/'
-    os.chdir(args.directory)
-    dr = os.getcwd()
-
-    httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST)
-
-    def run():
-        os.chdir(dr)
-        httpd.serve_forever()
-
-    daemon = Daemonize(
-        app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False
-    )
-
-    daemon.start()
-
-
-if __name__ == "__main__":
-    setup()
diff --git a/docker/Dockerfile b/docker/Dockerfile
index afd896ffc1..d619ee08ed 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -28,11 +28,13 @@ RUN apt-get update && apt-get install -y \
     libwebp-dev \
     libxml++2.6-dev \
     libxslt1-dev \
+    rustc \
     zlib1g-dev \
  && rm -rf /var/lib/apt/lists/*
 
 # Build dependencies that are not available as wheels, to speed up rebuilds
 RUN pip install --prefix="/install" --no-warn-script-location \
+        cryptography \
         frozendict \
         jaeger-client \
         opentracing \
diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index 2b7f01f7f7..0d74630370 100644
--- a/docker/Dockerfile-dhvirtualenv
+++ b/docker/Dockerfile-dhvirtualenv
@@ -27,6 +27,7 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
         wget
 
 # fetch and unpack the package
+# TODO: Upgrade to 1.2.2 once xenial is dropped
 RUN mkdir /dh-virtualenv
 RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
 RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
@@ -50,17 +51,22 @@ FROM ${distro}
 ARG distro=""
 ENV distro ${distro}
 
+# Python < 3.7 assumes LANG="C" means ASCII-only and throws on printing unicode
+# http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
 # Install the build dependencies
 #
 # NB: keep this list in sync with the list of build-deps in debian/control
 # TODO: it would be nice to do that automatically.
+# TODO: Remove the dh-systemd stanza after dropping support for Ubuntu xenial
+#       it's a transitional package on all other, more recent releases
 RUN apt-get update -qq -o Acquire::Languages=none \
     && env DEBIAN_FRONTEND=noninteractive apt-get install \
         -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \
         build-essential \
         debhelper \
         devscripts \
-        dh-systemd \
         libsystemd-dev \
         lsb-release \
         pkg-config \
@@ -70,7 +76,10 @@ RUN apt-get update -qq -o Acquire::Languages=none \
         python3-venv \
         sqlite3 \
         libpq-dev \
-        xmlsec1
+        xmlsec1 \
+    && ( env DEBIAN_FRONTEND=noninteractive apt-get install \
+         -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \
+         dh-systemd || true )
 
 COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /
 
diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml
index a808485c12..2ed570a5d1 100644
--- a/docker/conf/homeserver.yaml
+++ b/docker/conf/homeserver.yaml
@@ -198,12 +198,10 @@ old_signing_keys: {}
 key_refresh_interval: "1d" # 1 Day.
 
 # The trusted servers to download signing keys from.
-perspectives:
-  servers:
-    "matrix.org":
-      verify_keys:
-        "ed25519:auto":
-          key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+trusted_key_servers:
+  - server_name: matrix.org
+    verify_keys:
+      "ed25519:auto": "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
 
 password_config:
    enabled: true
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
index 71137c6dfc..90faeaaef0 100644
--- a/docs/admin_api/media_admin_api.md
+++ b/docs/admin_api/media_admin_api.md
@@ -1,3 +1,15 @@
+# Contents
+- [List all media in a room](#list-all-media-in-a-room)
+- [Quarantine media](#quarantine-media)
+  * [Quarantining media by ID](#quarantining-media-by-id)
+  * [Quarantining media in a room](#quarantining-media-in-a-room)
+  * [Quarantining all media of a user](#quarantining-all-media-of-a-user)
+  * [Protecting media from being quarantined](#protecting-media-from-being-quarantined)
+- [Delete local media](#delete-local-media)
+  * [Delete a specific local media](#delete-a-specific-local-media)
+  * [Delete local media by date or size](#delete-local-media-by-date-or-size)
+- [Purge Remote Media API](#purge-remote-media-api)
+
 # List all media in a room
 
 This API gets a list of known media in a room.
@@ -11,16 +23,16 @@ To use it, you will need to authenticate by providing an `access_token` for a
 server admin: see [README.rst](README.rst).
 
 The API returns a JSON body like the following:
-```
+```json
 {
-    "local": [
-        "mxc://localhost/xwvutsrqponmlkjihgfedcba",
-        "mxc://localhost/abcdefghijklmnopqrstuvwx"
-    ],
-    "remote": [
-        "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
-        "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
-    ]
+  "local": [
+    "mxc://localhost/xwvutsrqponmlkjihgfedcba",
+    "mxc://localhost/abcdefghijklmnopqrstuvwx"
+  ],
+  "remote": [
+    "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
+    "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
+  ]
 }
 ```
 
@@ -48,7 +60,7 @@ form of `abcdefg12345...`.
 
 Response:
 
-```
+```json
 {}
 ```
 
@@ -68,14 +80,18 @@ Where `room_id` is in the form of `!roomid12345:example.org`.
 
 Response:
 
-```
+```json
 {
-  "num_quarantined": 10  # The number of media items successfully quarantined
+  "num_quarantined": 10
 }
 ```
 
+The following fields are returned in the JSON response body:
+
+* `num_quarantined`: integer - The number of media items successfully quarantined
+
 Note that there is a legacy endpoint, `POST
-/_synapse/admin/v1/quarantine_media/<room_id >`, that operates the same.
+/_synapse/admin/v1/quarantine_media/<room_id>`, that operates the same.
 However, it is deprecated and may be removed in a future release.
 
 ## Quarantining all media of a user
@@ -92,23 +108,52 @@ POST /_synapse/admin/v1/user/<user_id>/media/quarantine
 {}
 ```
 
-Where `user_id` is in the form of `@bob:example.org`.
+URL Parameters
+
+* `user_id`: string - User ID in the form of `@bob:example.org`
 
 Response:
 
-```
+```json
 {
-  "num_quarantined": 10  # The number of media items successfully quarantined
+  "num_quarantined": 10
 }
 ```
 
+The following fields are returned in the JSON response body:
+
+* `num_quarantined`: integer - The number of media items successfully quarantined
+
+## Protecting media from being quarantined
+
+This API protects a single piece of local media from being quarantined using the
+above APIs. This is useful for sticker packs and other shared media which you do
+not want to get quarantined, especially when
+[quarantining media in a room](#quarantining-media-in-a-room).
+
+Request:
+
+```
+POST /_synapse/admin/v1/media/protect/<media_id>
+
+{}
+```
+
+Where `media_id` is in the  form of `abcdefg12345...`.
+
+Response:
+
+```json
+{}
+```
+
 # Delete local media
 This API deletes the *local* media from the disk of your own server.
 This includes any local thumbnails and copies of media downloaded from
 remote homeservers.
 This API will not affect media that has been uploaded to external
 media repositories (e.g https://github.com/turt2live/matrix-media-repo/).
-See also [purge_remote_media.rst](purge_remote_media.rst).
+See also [Purge Remote Media API](#purge-remote-media-api).
 
 ## Delete a specific local media
 Delete a specific `media_id`.
@@ -129,12 +174,12 @@ URL Parameters
 Response:
 
 ```json
-    {
-       "deleted_media": [
-          "abcdefghijklmnopqrstuvwx"
-       ],
-       "total": 1
-    }
+{
+  "deleted_media": [
+    "abcdefghijklmnopqrstuvwx"
+  ],
+  "total": 1
+}
 ```
 
 The following fields are returned in the JSON response body:
@@ -167,16 +212,51 @@ If `false` these files will be deleted. Defaults to `true`.
 Response:
 
 ```json
-    {
-       "deleted_media": [
-          "abcdefghijklmnopqrstuvwx",
-          "abcdefghijklmnopqrstuvwz"
-       ],
-       "total": 2
-    }
+{
+  "deleted_media": [
+    "abcdefghijklmnopqrstuvwx",
+    "abcdefghijklmnopqrstuvwz"
+  ],
+  "total": 2
+}
 ```
 
 The following fields are returned in the JSON response body:
 
 * `deleted_media`: an array of strings - List of deleted `media_id`
 * `total`: integer - Total number of deleted `media_id`
+
+# Purge Remote Media API
+
+The purge remote media API allows server admins to purge old cached remote media.
+
+The API is:
+
+```
+POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>
+
+{}
+```
+
+URL Parameters
+
+* `unix_timestamp_in_ms`: string representing a positive integer - Unix timestamp in ms.
+All cached media that was last accessed before this timestamp will be removed.
+
+Response:
+
+```json
+{
+  "deleted": 10
+}
+```
+
+The following fields are returned in the JSON response body:
+
+* `deleted`: integer - The number of media items successfully deleted
+
+To use it, you will need to authenticate by providing an `access_token` for a
+server admin: see [README.rst](README.rst).
+
+If the user re-requests purged remote media, synapse will re-request the media
+from the originating server.
diff --git a/docs/admin_api/purge_remote_media.rst b/docs/admin_api/purge_remote_media.rst
deleted file mode 100644
index 00cb6b0589..0000000000
--- a/docs/admin_api/purge_remote_media.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-Purge Remote Media API
-======================
-
-The purge remote media API allows server admins to purge old cached remote
-media.
-
-The API is::
-
-    POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms>
-
-    {}
-
-\... which will remove all cached media that was last accessed before
-``<unix_timestamp_in_ms>``.
-
-To use it, you will need to authenticate by providing an ``access_token`` for a
-server admin: see `README.rst <README.rst>`_.
-
-If the user re-requests purged remote media, synapse will re-request the media
-from the originating server.
diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md
index ae01a543c6..54fea2db6d 100644
--- a/docs/admin_api/purge_room.md
+++ b/docs/admin_api/purge_room.md
@@ -1,12 +1,13 @@
-Purge room API
-==============
+Deprecated: Purge room API
+==========================
+
+**The old Purge room API is deprecated and will be removed in a future release.
+See the new [Delete Room API](rooms.md#delete-room-api) for more details.**
 
 This API will remove all trace of a room from your database.
 
 All local users must have left the room before it can be removed.
 
-See also: [Delete Room API](rooms.md#delete-room-api)
-
 The API is:
 
 ```
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 004a802e17..3832b36407 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -1,3 +1,16 @@
+# Contents
+- [List Room API](#list-room-api)
+  * [Parameters](#parameters)
+  * [Usage](#usage)
+- [Room Details API](#room-details-api)
+- [Room Members API](#room-members-api)
+- [Delete Room API](#delete-room-api)
+  * [Parameters](#parameters-1)
+  * [Response](#response)
+  * [Undoing room shutdowns](#undoing-room-shutdowns)
+- [Make Room Admin API](#make-room-admin-api)
+- [Forward Extremities Admin API](#forward-extremities-admin-api)
+
 # List Room API
 
 The List Room admin API allows server admins to get a list of rooms on their
@@ -76,7 +89,7 @@ GET /_synapse/admin/v1/rooms
 
 Response:
 
-```
+```jsonc
 {
   "rooms": [
     {
@@ -128,7 +141,7 @@ GET /_synapse/admin/v1/rooms?search_term=TWIM
 
 Response:
 
-```
+```json
 {
   "rooms": [
     {
@@ -163,7 +176,7 @@ GET /_synapse/admin/v1/rooms?order_by=size
 
 Response:
 
-```
+```jsonc
 {
   "rooms": [
     {
@@ -219,14 +232,14 @@ GET /_synapse/admin/v1/rooms?order_by=size&from=100
 
 Response:
 
-```
+```jsonc
 {
   "rooms": [
     {
       "room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
       "name": "Music Theory",
       "canonical_alias": "#musictheory:matrix.org",
-      "joined_members": 127
+      "joined_members": 127,
       "joined_local_members": 2,
       "version": "1",
       "creator": "@foo:matrix.org",
@@ -243,7 +256,7 @@ Response:
       "room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk",
       "name": "weechat-matrix",
       "canonical_alias": "#weechat-matrix:termina.org.uk",
-      "joined_members": 137
+      "joined_members": 137,
       "joined_local_members": 20,
       "version": "4",
       "creator": "@foo:termina.org.uk",
@@ -278,6 +291,7 @@ The following fields are possible in the JSON response body:
 * `canonical_alias` - The canonical (main) alias address of the room.
 * `joined_members` - How many users are currently in the room.
 * `joined_local_members` - How many local users are currently in the room.
+* `joined_local_devices` - How many local devices are currently in the room.
 * `version` - The version of the room as a string.
 * `creator` - The `user_id` of the room creator.
 * `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
@@ -300,15 +314,16 @@ GET /_synapse/admin/v1/rooms/<room_id>
 
 Response:
 
-```
+```json
 {
   "room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
   "name": "Music Theory",
   "avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV",
   "topic": "Theory, Composition, Notation, Analysis",
   "canonical_alias": "#musictheory:matrix.org",
-  "joined_members": 127
+  "joined_members": 127,
   "joined_local_members": 2,
+  "joined_local_devices": 2,
   "version": "1",
   "creator": "@foo:matrix.org",
   "encryption": null,
@@ -342,23 +357,51 @@ GET /_synapse/admin/v1/rooms/<room_id>/members
 
 Response:
 
-```
+```json
 {
   "members": [
     "@foo:matrix.org",
     "@bar:matrix.org",
-    "@foobar:matrix.org
-    ],
+    "@foobar:matrix.org"
+  ],
   "total": 3
 }
 ```
 
+# Room State API
+
+The Room State admin API allows server admins to get a list of all state events in a room.
+
+The response includes the following fields:
+
+* `state` - The current state of the room at the time of request.
+
+## Usage
+
+A standard request:
+
+```
+GET /_synapse/admin/v1/rooms/<room_id>/state
+
+{}
+```
+
+Response:
+
+```json
+{
+  "state": [
+    {"type": "m.room.create", "state_key": "", "etc": true},
+    {"type": "m.room.power_levels", "state_key": "", "etc": true},
+    {"type": "m.room.name", "state_key": "", "etc": true}
+  ]
+}
+```
+
 # Delete Room API
 
 The Delete Room admin API allows server admins to remove rooms from server
 and block these rooms.
-It is a combination and improvement of "[Shutdown room](shutdown_room.md)"
-and "[Purge room](purge_room.md)" API.
 
 Shuts down a room. Moves all local users and room aliases automatically to a
 new room if `new_room_user_id` is set. Otherwise local users only
@@ -455,3 +498,99 @@ The following fields are returned in the JSON response body:
 * `local_aliases` - An array of strings representing the local aliases that were migrated from
                     the old room to the new.
 * `new_room_id` - A string representing the room ID of the new room.
+
+
+## Undoing room shutdowns
+
+*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level,
+the structure can and does change without notice.
+
+First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it
+never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible
+to recover at all:
+
+* If the room was invite-only, your users will need to be re-invited.
+* If the room no longer has any members at all, it'll be impossible to rejoin.
+* The first user to rejoin will have to do so via an alias on a different server.
+
+With all that being said, if you still want to try and recover the room:
+
+1. For safety reasons, shut down Synapse.
+2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';`
+   * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`.
+   * The room ID is the same one supplied to the shutdown room API, not the Content Violation room.
+3. Restart Synapse.
+
+You will have to manually handle, if you so choose, the following:
+
+* Aliases that would have been redirected to the Content Violation room.
+* Users that would have been booted from the room (and will have been force-joined to the Content Violation room).
+* Removal of the Content Violation room if desired.
+
+
+# Make Room Admin API
+
+Grants another user the highest power available to a local user who is in the room.
+If the user is not in the room, and it is not publicly joinable, then invite the user.
+
+By default the server admin (the caller) is granted power, but another user can
+optionally be specified, e.g.:
+
+```
+    POST /_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
+    {
+        "user_id": "@foo:example.com"
+    }
+```
+
+# Forward Extremities Admin API
+
+Enables querying and deleting forward extremities from rooms. When a lot of forward
+extremities accumulate in a room, performance can become degraded. For details, see 
+[#1760](https://github.com/matrix-org/synapse/issues/1760).
+
+## Check for forward extremities
+
+To check the status of forward extremities for a room:
+
+```
+    GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response as follows will be returned:
+
+```json
+{
+  "count": 1,
+  "results": [
+    {
+      "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh",
+      "state_group": 439,
+      "depth": 123,
+      "received_ts": 1611263016761
+    }
+  ]
+}    
+```
+
+## Deleting forward extremities
+
+**WARNING**: Please ensure you know what you're doing and have read 
+the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760).
+Under no situations should this API be executed as an automated maintenance task!
+
+If a room has lots of forward extremities, the extra can be
+deleted as follows:
+
+```
+    DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response as follows will be returned, indicating the amount of forward extremities
+that were deleted.
+
+```json
+{
+  "deleted": 1
+}
+```
diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md
index 9b1cb1c184..856a629487 100644
--- a/docs/admin_api/shutdown_room.md
+++ b/docs/admin_api/shutdown_room.md
@@ -1,4 +1,7 @@
-# Shutdown room API
+# Deprecated: Shutdown room API
+
+**The old Shutdown room API is deprecated and will be removed in a future release.
+See the new [Delete Room API](rooms.md#delete-room-api) for more details.**
 
 Shuts down a room, preventing new joins and moves local users and room aliases automatically
 to a new room. The new room will be created with the user specified by the
@@ -10,8 +13,6 @@ disallow any further invites or joins.
 The local server will only have the power to move local user and room aliases to
 the new room. Users on other servers will be unaffected.
 
-See also: [Delete Room API](rooms.md#delete-room-api)
-
 ## API
 
 You will need to authenticate with an access token for an admin user.
diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst
index 1473a3d4e3..1eb674939e 100644
--- a/docs/admin_api/user_admin_api.rst
+++ b/docs/admin_api/user_admin_api.rst
@@ -30,7 +30,12 @@ It returns a JSON body like the following:
         ],
         "avatar_url": "<avatar_url>",
         "admin": false,
-        "deactivated": false
+        "deactivated": false,
+        "password_hash": "$2b$12$p9B4GkqYdRTPGD",
+        "creation_ts": 1560432506,
+        "appservice_id": null,
+        "consent_server_notice_sent": null,
+        "consent_version": null
     }
 
 URL parameters:
@@ -93,6 +98,8 @@ Body parameters:
 
 - ``deactivated``, optional. If unspecified, deactivation state will be left
   unchanged on existing accounts and set to ``false`` for new accounts.
+  A user cannot be erased by deactivating with this API. For details on deactivating users see
+  `Deactivate Account <#deactivate-account>`_.
 
 If the user already exists then optional parameters default to the current value.
 
@@ -139,7 +146,6 @@ A JSON body is returned with the following shape:
         "users": [
             {
                 "name": "<user_id1>",
-                "password_hash": "<password_hash1>",
                 "is_guest": 0,
                 "admin": 0,
                 "user_type": null,
@@ -148,7 +154,6 @@ A JSON body is returned with the following shape:
                 "avatar_url": null
             }, {
                 "name": "<user_id2>",
-                "password_hash": "<password_hash2>",
                 "is_guest": 0,
                 "admin": 1,
                 "user_type": null,
@@ -245,6 +250,25 @@ server admin: see `README.rst <README.rst>`_.
 The erase parameter is optional and defaults to ``false``.
 An empty body may be passed for backwards compatibility.
 
+The following actions are performed when deactivating an user:
+
+- Try to unpind 3PIDs from the identity server
+- Remove all 3PIDs from the homeserver
+- Delete all devices and E2EE keys
+- Delete all access tokens
+- Delete the password hash
+- Removal from all rooms the user is a member of
+- Remove the user from the user directory
+- Reject all pending invites
+- Remove all account validity information related to the user
+
+The following additional actions are performed during deactivation if``erase``
+is set to ``true``:
+
+- Remove the user's display name
+- Remove the user's avatar URL
+- Mark the user as erased
+
 
 Reset password
 ==============
@@ -334,6 +358,10 @@ A response body like the following is returned:
         "total": 2
     }
 
+The server returns the list of rooms of which the user and the server
+are member. If the user is local, all the rooms of which the user is
+member are returned.
+
 **Parameters**
 
 The following parameters should be set in the URL:
@@ -732,3 +760,33 @@ The following fields are returned in the JSON response body:
 - ``total`` - integer - Number of pushers.
 
 See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
+
+Shadow-banning users
+====================
+
+Shadow-banning is a useful tool for moderating malicious or egregiously abusive users.
+A shadow-banned users receives successful responses to their client-server API requests,
+but the events are not propagated into rooms. This can be an effective tool as it
+(hopefully) takes longer for the user to realise they are being moderated before
+pivoting to another account.
+
+Shadow-banning a user should be used as a tool of last resort and may lead to confusing
+or broken behaviour for the client. A shadow-banned user will not receive any
+notification and it is generally more appropriate to ban or kick abusive users.
+A shadow-banned user will be unable to contact anyone on the server.
+
+The API is::
+
+  POST /_synapse/admin/v1/users/<user_id>/shadow_ban
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+  be local.
diff --git a/docs/auth_chain_diff.dot b/docs/auth_chain_diff.dot
new file mode 100644
index 0000000000..978d579ada
--- /dev/null
+++ b/docs/auth_chain_diff.dot
@@ -0,0 +1,32 @@
+digraph auth {
+    nodesep=0.5;
+    rankdir="RL";
+
+    C [label="Create (1,1)"];
+
+    BJ [label="Bob's Join (2,1)", color=red];
+    BJ2 [label="Bob's Join (2,2)", color=red];
+    BJ2 -> BJ [color=red, dir=none];
+
+    subgraph cluster_foo {
+        A1 [label="Alice's invite (4,1)", color=blue];
+        A2 [label="Alice's Join (4,2)", color=blue];
+        A3 [label="Alice's Join (4,3)", color=blue];
+        A3 -> A2 -> A1 [color=blue, dir=none];
+        color=none;
+    }
+
+    PL1 [label="Power Level (3,1)", color=darkgreen];
+    PL2 [label="Power Level (3,2)", color=darkgreen];
+    PL2 -> PL1 [color=darkgreen, dir=none];
+
+    {rank = same; C; BJ; PL1; A1;}
+
+    A1 -> C [color=grey];
+    A1 -> BJ [color=grey];
+    PL1 -> C [color=grey];
+    BJ2 -> PL1 [penwidth=2];
+
+    A3 -> PL2 [penwidth=2];
+    A1 -> PL1 -> BJ -> C [penwidth=2];
+}
diff --git a/docs/auth_chain_diff.dot.png b/docs/auth_chain_diff.dot.png
new file mode 100644
index 0000000000..771c07308f
--- /dev/null
+++ b/docs/auth_chain_diff.dot.png
Binary files differdiff --git a/docs/auth_chain_difference_algorithm.md b/docs/auth_chain_difference_algorithm.md
new file mode 100644
index 0000000000..30f72a70da
--- /dev/null
+++ b/docs/auth_chain_difference_algorithm.md
@@ -0,0 +1,108 @@
+# Auth Chain Difference Algorithm
+
+The auth chain difference algorithm is used by V2 state resolution, where a
+naive implementation can be a significant source of CPU and DB usage.
+
+### Definitions
+
+A *state set* is a set of state events; e.g. the input of a state resolution
+algorithm is a collection of state sets.
+
+The *auth chain* of a set of events are all the events' auth events and *their*
+auth events, recursively (i.e. the events reachable by walking the graph induced
+by an event's auth events links).
+
+The *auth chain difference* of a collection of state sets is the union minus the
+intersection of the sets of auth chains corresponding to the state sets, i.e an
+event is in the auth chain difference if it is reachable by walking the auth
+event graph from at least one of the state sets but not from *all* of the state
+sets.
+
+## Breadth First Walk Algorithm
+
+A way of calculating the auth chain difference without calculating the full auth
+chains for each state set is to do a parallel breadth first walk (ordered by
+depth) of each state set's auth chain. By tracking which events are reachable
+from each state set we can finish early if every pending event is reachable from
+every state set.
+
+This can work well for state sets that have a small auth chain difference, but
+can be very inefficient for larger differences. However, this algorithm is still
+used if we don't have a chain cover index for the room (e.g. because we're in
+the process of indexing it).
+
+## Chain Cover Index
+
+Synapse computes auth chain differences by pre-computing a "chain cover" index
+for the auth chain in a room, allowing efficient reachability queries like "is
+event A in the auth chain of event B". This is done by assigning every event a
+*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
+between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
+is in the auth chain of `B`) if and only if either:
+
+1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
+   sequence number; or
+2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
+   `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
+
+There are actually two potential implementations, one where we store links from
+each chain to every other reachable chain (the transitive closure of the links
+graph), and one where we remove redundant links (the transitive reduction of the
+links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
+would not be stored. Synapse uses the former implementations so that it doesn't
+need to recurse to test reachability between chains.
+
+### Example
+
+An example auth graph would look like the following, where chains have been
+formed based on type/state_key and are denoted by colour and are labelled with
+`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
+are those that would be remove in the second implementation described above).
+
+![Example](auth_chain_diff.dot.png)
+
+Note that we don't include all links between events and their auth events, as
+most of those links would be redundant. For example, all events point to the
+create event, but each chain only needs the one link from it's base to the
+create event.
+
+## Using the Index
+
+This index can be used to calculate the auth chain difference of the state sets
+by looking at the chain ID and sequence numbers reachable from each state set:
+
+1. For every state set lookup the chain ID/sequence numbers of each state event
+2. Use the index to find all chains and the maximum sequence number reachable
+   from each state set.
+3. The auth chain difference is then all events in each chain that have sequence
+   numbers between the maximum sequence number reachable from *any* state set and
+   the minimum reachable by *all* state sets (if any).
+
+Note that steps 2 is effectively calculating the auth chain for each state set
+(in terms of chain IDs and sequence numbers), and step 3 is calculating the
+difference between the union and intersection of the auth chains.
+
+### Worked Example
+
+For example, given the above graph, we can calculate the difference between
+state sets consisting of:
+
+1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
+2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
+
+Using the index we see that the following auth chains are reachable from each
+state set:
+
+1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
+2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
+
+And so, for each the ranges that are in the auth chain difference:
+1. Chain 1: None, (since everything can reach the create event).
+2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
+   sets and the maximum reachable is `2` (corresponding to Bob's second join).
+3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
+   level).
+4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
+
+So the final result is: Bob's second join `(2,2)`, the second power level
+`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.
diff --git a/docs/dev/cas.md b/docs/dev/cas.md
index f8d02cc82c..592b2d8d4f 100644
--- a/docs/dev/cas.md
+++ b/docs/dev/cas.md
@@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
 You should now have a Django project configured to serve CAS authentication with
 a single user created.
 
-## Configure Synapse (and Riot) to use CAS
+## Configure Synapse (and Element) to use CAS
 
 1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
    running Django test server:
@@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.
 
 ## Testing the configuration
 
-Then in Riot:
+Then in Element:
 
-1. Visit the login page with a Riot pointing at your homeserver.
+1. Visit the login page with a Element pointing at your homeserver.
 2. Click the Single Sign-On button.
 3. Login using the credentials created with `createsuperuser`.
 4. You should be logged in.
diff --git a/docs/openid.md b/docs/openid.md
index da391f74aa..9d19368845 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -42,40 +42,41 @@ as follows:
  * For other installation mechanisms, see the documentation provided by the
    maintainer.
 
-To enable the OpenID integration, you should then add an `oidc_config` section
-to your configuration file (or uncomment the `enabled: true` line in the
-existing section). See [sample_config.yaml](./sample_config.yaml) for some
-sample settings, as well as the text below for example configurations for
-specific providers.
+To enable the OpenID integration, you should then add a section to the `oidc_providers`
+setting in your configuration file (or uncomment one of the existing examples).
+See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as
+the text below for example configurations for specific providers.
 
 ## Sample configs
 
 Here are a few configs for providers that should work with Synapse.
 
 ### Microsoft Azure Active Directory
-Azure AD can act as an OpenID Connect Provider. Register a new application under 
+Azure AD can act as an OpenID Connect Provider. Register a new application under
 *App registrations* in the Azure AD management console. The RedirectURI for your
-application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback`
+application should point to your matrix server:
+`[synapse public baseurl]/_synapse/client/oidc/callback`
 
-Go to *Certificates & secrets* and register a new client secret. Make note of your 
+Go to *Certificates & secrets* and register a new client secret. Make note of your
 Directory (tenant) ID as it will be used in the Azure links.
 Edit your Synapse config file and change the `oidc_config` section:
 
 ```yaml
-oidc_config:
-   enabled: true
-   issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
-   client_id: "<client id>"
-   client_secret: "<client secret>"
-   scopes: ["openid", "profile"]
-   authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
-   token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
-   userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
-
-   user_mapping_provider:
-     config:
-       localpart_template: "{{ user.preferred_username.split('@')[0] }}"
-       display_name_template: "{{ user.name }}"
+oidc_providers:
+  - idp_id: microsoft
+    idp_name: Microsoft
+    issuer: "https://login.microsoftonline.com/<tenant id>/v2.0"
+    client_id: "<client id>"
+    client_secret: "<client secret>"
+    scopes: ["openid", "profile"]
+    authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize"
+    token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token"
+    userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo"
+
+    user_mapping_provider:
+      config:
+        localpart_template: "{{ user.preferred_username.split('@')[0] }}"
+        display_name_template: "{{ user.name }}"
 ```
 
 ### [Dex][dex-idp]
@@ -94,7 +95,7 @@ staticClients:
 - id: synapse
   secret: secret
   redirectURIs:
-  - '[synapse public baseurl]/_synapse/oidc/callback'
+  - '[synapse public baseurl]/_synapse/client/oidc/callback'
   name: 'Synapse'
 ```
 
@@ -103,21 +104,22 @@ Run with `dex serve examples/config-dev.yaml`.
 Synapse config:
 
 ```yaml
-oidc_config:
-   enabled: true
-   skip_verification: true # This is needed as Dex is served on an insecure endpoint
-   issuer: "http://127.0.0.1:5556/dex"
-   client_id: "synapse"
-   client_secret: "secret"
-   scopes: ["openid", "profile"]
-   user_mapping_provider:
-     config:
-       localpart_template: "{{ user.name }}"
-       display_name_template: "{{ user.name|capitalize }}"
+oidc_providers:
+  - idp_id: dex
+    idp_name: "My Dex server"
+    skip_verification: true # This is needed as Dex is served on an insecure endpoint
+    issuer: "http://127.0.0.1:5556/dex"
+    client_id: "synapse"
+    client_secret: "secret"
+    scopes: ["openid", "profile"]
+    user_mapping_provider:
+      config:
+        localpart_template: "{{ user.name }}"
+        display_name_template: "{{ user.name|capitalize }}"
 ```
 ### [Keycloak][keycloak-idp]
 
-[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat. 
+[Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
 
 Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
 
@@ -139,7 +141,7 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
 | Enabled | `On` |
 | Client Protocol | `openid-connect` |
 | Access Type | `confidential` |
-| Valid Redirect URIs | `[synapse public baseurl]/_synapse/oidc/callback` |
+| Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` |
 
 5. Click `Save`
 6. On the Credentials tab, update the fields:
@@ -152,17 +154,22 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
 8. Copy Secret
 
 ```yaml
-oidc_config:
-   enabled: true
-   issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
-   client_id: "synapse"
-   client_secret: "copy secret generated from above"
-   scopes: ["openid", "profile"]
+oidc_providers:
+  - idp_id: keycloak
+    idp_name: "My KeyCloak server"
+    issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}"
+    client_id: "synapse"
+    client_secret: "copy secret generated from above"
+    scopes: ["openid", "profile"]
+    user_mapping_provider:
+      config:
+        localpart_template: "{{ user.preferred_username }}"
+        display_name_template: "{{ user.name }}"
 ```
 ### [Auth0][auth0]
 
 1. Create a regular web application for Synapse
-2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/oidc/callback`
+2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/client/oidc/callback`
 3. Add a rule to add the `preferred_username` claim.
    <details>
     <summary>Code sample</summary>
@@ -187,16 +194,17 @@ oidc_config:
 Synapse config:
 
 ```yaml
-oidc_config:
-   enabled: true
-   issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
-   client_id: "your-client-id" # TO BE FILLED
-   client_secret: "your-client-secret" # TO BE FILLED
-   scopes: ["openid", "profile"]
-   user_mapping_provider:
-     config:
-       localpart_template: "{{ user.preferred_username }}"
-       display_name_template: "{{ user.name }}"
+oidc_providers:
+  - idp_id: auth0
+    idp_name: Auth0
+    issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED
+    client_id: "your-client-id" # TO BE FILLED
+    client_secret: "your-client-secret" # TO BE FILLED
+    scopes: ["openid", "profile"]
+    user_mapping_provider:
+      config:
+        localpart_template: "{{ user.preferred_username }}"
+        display_name_template: "{{ user.name }}"
 ```
 
 ### GitHub
@@ -210,26 +218,28 @@ login mechanism needs an attribute to uniquely identify users, and that endpoint
 does not return a `sub` property, an alternative `subject_claim` has to be set.
 
 1. Create a new OAuth application: https://github.com/settings/applications/new.
-2. Set the callback URL to `[synapse public baseurl]/_synapse/oidc/callback`.
+2. Set the callback URL to `[synapse public baseurl]/_synapse/client/oidc/callback`.
 
 Synapse config:
 
 ```yaml
-oidc_config:
-   enabled: true
-   discover: false
-   issuer: "https://github.com/"
-   client_id: "your-client-id" # TO BE FILLED
-   client_secret: "your-client-secret" # TO BE FILLED
-   authorization_endpoint: "https://github.com/login/oauth/authorize"
-   token_endpoint: "https://github.com/login/oauth/access_token"
-   userinfo_endpoint: "https://api.github.com/user"
-   scopes: ["read:user"]
-   user_mapping_provider:
-     config:
-       subject_claim: "id"
-       localpart_template: "{{ user.login }}"
-       display_name_template: "{{ user.name }}"
+oidc_providers:
+  - idp_id: github
+    idp_name: Github
+    idp_brand: "org.matrix.github"  # optional: styling hint for clients
+    discover: false
+    issuer: "https://github.com/"
+    client_id: "your-client-id" # TO BE FILLED
+    client_secret: "your-client-secret" # TO BE FILLED
+    authorization_endpoint: "https://github.com/login/oauth/authorize"
+    token_endpoint: "https://github.com/login/oauth/access_token"
+    userinfo_endpoint: "https://api.github.com/user"
+    scopes: ["read:user"]
+    user_mapping_provider:
+      config:
+        subject_claim: "id"
+        localpart_template: "{{ user.login }}"
+        display_name_template: "{{ user.name }}"
 ```
 
 ### [Google][google-idp]
@@ -239,60 +249,142 @@ oidc_config:
 2. add an "OAuth Client ID" for a Web Application under "Credentials".
 3. Copy the Client ID and Client Secret, and add the following to your synapse config:
    ```yaml
-   oidc_config:
-     enabled: true
-     issuer: "https://accounts.google.com/"
-     client_id: "your-client-id" # TO BE FILLED
-     client_secret: "your-client-secret" # TO BE FILLED
-     scopes: ["openid", "profile"]
-     user_mapping_provider:
-       config:
-         localpart_template: "{{ user.given_name|lower }}"
-         display_name_template: "{{ user.name }}"
+   oidc_providers:
+     - idp_id: google
+       idp_name: Google
+       idp_brand: "org.matrix.google"  # optional: styling hint for clients
+       issuer: "https://accounts.google.com/"
+       client_id: "your-client-id" # TO BE FILLED
+       client_secret: "your-client-secret" # TO BE FILLED
+       scopes: ["openid", "profile"]
+       user_mapping_provider:
+         config:
+           localpart_template: "{{ user.given_name|lower }}"
+           display_name_template: "{{ user.name }}"
    ```
 4. Back in the Google console, add this Authorized redirect URI: `[synapse
-   public baseurl]/_synapse/oidc/callback`.
+   public baseurl]/_synapse/client/oidc/callback`.
 
 ### Twitch
 
 1. Setup a developer account on [Twitch](https://dev.twitch.tv/)
 2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/)
-3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/oidc/callback`
+3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/client/oidc/callback`
 
 Synapse config:
 
 ```yaml
-oidc_config:
-  enabled: true
-  issuer: "https://id.twitch.tv/oauth2/"
-  client_id: "your-client-id" # TO BE FILLED
-  client_secret: "your-client-secret" # TO BE FILLED
-  client_auth_method: "client_secret_post"
-  user_mapping_provider:
-    config:
-      localpart_template: "{{ user.preferred_username }}"
-      display_name_template: "{{ user.name }}"
+oidc_providers:
+  - idp_id: twitch
+    idp_name: Twitch
+    issuer: "https://id.twitch.tv/oauth2/"
+    client_id: "your-client-id" # TO BE FILLED
+    client_secret: "your-client-secret" # TO BE FILLED
+    client_auth_method: "client_secret_post"
+    user_mapping_provider:
+      config:
+        localpart_template: "{{ user.preferred_username }}"
+        display_name_template: "{{ user.name }}"
 ```
 
 ### GitLab
 
 1. Create a [new application](https://gitlab.com/profile/applications).
 2. Add the `read_user` and `openid` scopes.
-3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
+3. Add this Callback URL: `[synapse public baseurl]/_synapse/client/oidc/callback`
 
 Synapse config:
 
 ```yaml
-oidc_config:
-  enabled: true
-  issuer: "https://gitlab.com/"
-  client_id: "your-client-id" # TO BE FILLED
-  client_secret: "your-client-secret" # TO BE FILLED
-  client_auth_method: "client_secret_post"
-  scopes: ["openid", "read_user"]
-  user_profile_method: "userinfo_endpoint"
-  user_mapping_provider:
-    config:
-      localpart_template: '{{ user.nickname }}'
-      display_name_template: '{{ user.name }}'
+oidc_providers:
+  - idp_id: gitlab
+    idp_name: Gitlab
+    idp_brand: "org.matrix.gitlab"  # optional: styling hint for clients
+    issuer: "https://gitlab.com/"
+    client_id: "your-client-id" # TO BE FILLED
+    client_secret: "your-client-secret" # TO BE FILLED
+    client_auth_method: "client_secret_post"
+    scopes: ["openid", "read_user"]
+    user_profile_method: "userinfo_endpoint"
+    user_mapping_provider:
+      config:
+        localpart_template: '{{ user.nickname }}'
+        display_name_template: '{{ user.name }}'
+```
+
+### Facebook
+
+Like Github, Facebook provide a custom OAuth2 API rather than an OIDC-compliant
+one so requires a little more configuration.
+
+0. You will need a Facebook developer account. You can register for one
+   [here](https://developers.facebook.com/async/registration/).
+1. On the [apps](https://developers.facebook.com/apps/) page of the developer
+   console, "Create App", and choose "Build Connected Experiences".
+2. Once the app is created, add "Facebook Login" and choose "Web". You don't
+   need to go through the whole form here.
+3. In the left-hand menu, open "Products"/"Facebook Login"/"Settings".
+   * Add `[synapse public baseurl]/_synapse/client/oidc/callback` as an OAuth Redirect
+     URL.
+4. In the left-hand menu, open "Settings/Basic". Here you can copy the "App ID"
+   and "App Secret" for use below.
+
+Synapse config:
+
+```yaml
+  - idp_id: facebook
+    idp_name: Facebook
+    idp_brand: "org.matrix.facebook"  # optional: styling hint for clients
+    discover: false
+    issuer: "https://facebook.com"
+    client_id: "your-client-id" # TO BE FILLED
+    client_secret: "your-client-secret" # TO BE FILLED
+    scopes: ["openid", "email"]
+    authorization_endpoint: https://facebook.com/dialog/oauth
+    token_endpoint: https://graph.facebook.com/v9.0/oauth/access_token
+    user_profile_method: "userinfo_endpoint"
+    userinfo_endpoint: "https://graph.facebook.com/v9.0/me?fields=id,name,email,picture"
+    user_mapping_provider:
+      config:
+        subject_claim: "id"
+        display_name_template: "{{ user.name }}"
+```
+
+Relevant documents:
+ * https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow
+ * Using Facebook's Graph API: https://developers.facebook.com/docs/graph-api/using-graph-api/
+ * Reference to the User endpoint: https://developers.facebook.com/docs/graph-api/reference/user
+
+### Gitea
+
+Gitea is, like Github, not an OpenID provider, but just an OAuth2 provider.
+
+The [`/user` API endpoint](https://try.gitea.io/api/swagger#/user/userGetCurrent)
+can be used to retrieve information on the authenticated user. As the Synapse
+login mechanism needs an attribute to uniquely identify users, and that endpoint
+does not return a `sub` property, an alternative `subject_claim` has to be set.
+
+1. Create a new application.
+2. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback`
+
+Synapse config:
+
+```yaml
+oidc_providers:
+  - idp_id: gitea
+    idp_name: Gitea
+    discover: false
+    issuer: "https://your-gitea.com/"
+    client_id: "your-client-id" # TO BE FILLED
+    client_secret: "your-client-secret" # TO BE FILLED
+    client_auth_method: client_secret_post
+    scopes: [] # Gitea doesn't support Scopes
+    authorization_endpoint: "https://your-gitea.com/login/oauth/authorize"
+    token_endpoint: "https://your-gitea.com/login/oauth/access_token"
+    userinfo_endpoint: "https://your-gitea.com/api/v1/user"
+    user_mapping_provider:
+      config:
+        subject_claim: "id"
+        localpart_template: "{{ user.login }}"
+        display_name_template: "{{ user.full_name }}" 
 ```
diff --git a/docs/postgres.md b/docs/postgres.md
index c30cc1fd8c..680685d04e 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -18,7 +18,7 @@ connect to a postgres database.
     virtualenv](../INSTALL.md#installing-from-source), you can install
     the library with:
 
-        ~/synapse/env/bin/pip install matrix-synapse[postgres]
+        ~/synapse/env/bin/pip install "matrix-synapse[postgres]"
 
     (substituting the path to your virtualenv for `~/synapse/env`, if
     you used a different path). You will require the postgres
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 3439aa3594..f2fe4e1846 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -67,11 +67,12 @@ pid_file: DATADIR/homeserver.pid
 #
 #web_client_location: https://riot.example.com/
 
-# The public-facing base URL that clients use to access this HS
-# (not including _matrix/...). This is the same URL a user would
-# enter into the 'custom HS URL' field on their client. If you
-# use synapse with a reverse proxy, this should be the URL to reach
-# synapse via the proxy.
+# The public-facing base URL that clients use to access this Homeserver (not
+# including _matrix/...). This is the same URL a user might enter into the
+# 'Custom Homeserver URL' field on their client. If you use Synapse with a
+# reverse proxy, this should be the URL to reach Synapse via the proxy.
+# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
+# 'listeners' below).
 #
 #public_baseurl: https://example.com/
 
@@ -144,6 +145,47 @@ pid_file: DATADIR/homeserver.pid
 #
 #enable_search: false
 
+# Prevent outgoing requests from being sent to the following blacklisted IP address
+# CIDR ranges. If this option is not specified then it defaults to private IP
+# address ranges (see the example below).
+#
+# The blacklist applies to the outbound requests for federation, identity servers,
+# push servers, and for checking key validity for third-party invite events.
+#
+# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+# listed here, since they correspond to unroutable addresses.)
+#
+# This option replaces federation_ip_range_blacklist in Synapse v1.25.0.
+#
+#ip_range_blacklist:
+#  - '127.0.0.0/8'
+#  - '10.0.0.0/8'
+#  - '172.16.0.0/12'
+#  - '192.168.0.0/16'
+#  - '100.64.0.0/10'
+#  - '192.0.0.0/24'
+#  - '169.254.0.0/16'
+#  - '198.18.0.0/15'
+#  - '192.0.2.0/24'
+#  - '198.51.100.0/24'
+#  - '203.0.113.0/24'
+#  - '224.0.0.0/4'
+#  - '::1/128'
+#  - 'fe80::/10'
+#  - 'fc00::/7'
+
+# List of IP address CIDR ranges that should be allowed for federation,
+# identity servers, push servers, and for checking key validity for
+# third-party invite events. This is useful for specifying exceptions to
+# wide-ranging blacklisted target IP ranges - e.g. for communication with
+# a push server only visible in your network.
+#
+# This whitelist overrides ip_range_blacklist and defaults to an empty
+# list.
+#
+#ip_range_whitelist:
+#   - '192.168.1.1'
+
 # List of ports that Synapse should listen on, their purpose and their
 # configuration.
 #
@@ -642,27 +684,6 @@ acme:
 #  - nyc.example.com
 #  - syd.example.com
 
-# Prevent federation requests from being sent to the following
-# blacklist IP address CIDR ranges. If this option is not specified, or
-# specified with an empty list, no ip range blacklist will be enforced.
-#
-# As of Synapse v1.4.0 this option also affects any outbound requests to identity
-# servers provided by user input.
-#
-# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
-# listed here, since they correspond to unroutable addresses.)
-#
-federation_ip_range_blacklist:
-  - '127.0.0.0/8'
-  - '10.0.0.0/8'
-  - '172.16.0.0/12'
-  - '192.168.0.0/16'
-  - '100.64.0.0/10'
-  - '169.254.0.0/16'
-  - '::1/128'
-  - 'fe80::/64'
-  - 'fc00::/7'
-
 # Report prometheus metrics on the age of PDUs being sent to and received from
 # the following domains. This can be used to give an idea of "delay" on inbound
 # and outbound federation, though be aware that any delay can be due to problems
@@ -799,6 +820,9 @@ log_config: "CONFDIR/SERVERNAME.log.config"
 #     users are joining rooms the server is already in (this is cheap) vs
 #     "remote" for when users are trying to join rooms not on the server (which
 #     can be more expensive)
+#   - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+#   - two for ratelimiting how often invites can be sent in a room or to a
+#     specific user.
 #
 # The defaults are as shown below.
 #
@@ -832,7 +856,18 @@ log_config: "CONFDIR/SERVERNAME.log.config"
 #  remote:
 #    per_second: 0.01
 #    burst_count: 3
-
+#
+#rc_3pid_validation:
+#  per_second: 0.003
+#  burst_count: 5
+#
+#rc_invites:
+#  per_room:
+#    per_second: 0.3
+#    burst_count: 10
+#  per_user:
+#    per_second: 0.003
+#    burst_count: 5
 
 # Ratelimiting settings for incoming federation
 #
@@ -953,9 +988,15 @@ media_store_path: "DATADIR/media_store"
 #  - '172.16.0.0/12'
 #  - '192.168.0.0/16'
 #  - '100.64.0.0/10'
+#  - '192.0.0.0/24'
 #  - '169.254.0.0/16'
+#  - '198.18.0.0/15'
+#  - '192.0.2.0/24'
+#  - '198.51.100.0/24'
+#  - '203.0.113.0/24'
+#  - '224.0.0.0/4'
 #  - '::1/128'
-#  - 'fe80::/64'
+#  - 'fe80::/10'
 #  - 'fc00::/7'
 
 # List of IP address CIDR ranges that the URL preview spider is allowed
@@ -1528,10 +1569,10 @@ trusted_key_servers:
 # enable SAML login.
 #
 # Once SAML support is enabled, a metadata file will be exposed at
-# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
+# https://<server>:<port>/_synapse/client/saml2/metadata.xml, which you may be able to
 # use to configure your SAML IdP with. Alternatively, you can manually configure
 # the IdP to use an ACS location of
-# https://<server>:<port>/_matrix/saml2/authn_response.
+# https://<server>:<port>/_synapse/client/saml2/authn_response.
 #
 saml2_config:
   # `sp_config` is the configuration for the pysaml2 Service Provider.
@@ -1688,140 +1729,173 @@ saml2_config:
   #idp_entityid: 'https://our_idp/entityid'
 
 
-# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
+# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+# and login.
 #
-# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
-# for some example configurations.
+# Options for each entry include:
 #
-oidc_config:
-  # Uncomment the following to enable authorization against an OpenID Connect
-  # server. Defaults to false.
-  #
-  #enabled: true
-
-  # Uncomment the following to disable use of the OIDC discovery mechanism to
-  # discover endpoints. Defaults to true.
-  #
-  #discover: false
-
-  # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
-  # discover the provider's endpoints.
-  #
-  # Required if 'enabled' is true.
-  #
-  #issuer: "https://accounts.example.com/"
-
-  # oauth2 client id to use.
-  #
-  # Required if 'enabled' is true.
-  #
-  #client_id: "provided-by-your-issuer"
-
-  # oauth2 client secret to use.
-  #
-  # Required if 'enabled' is true.
-  #
-  #client_secret: "provided-by-your-issuer"
-
-  # auth method to use when exchanging the token.
-  # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
-  # 'none'.
-  #
-  #client_auth_method: client_secret_post
-
-  # list of scopes to request. This should normally include the "openid" scope.
-  # Defaults to ["openid"].
-  #
-  #scopes: ["openid", "profile"]
-
-  # the oauth2 authorization endpoint. Required if provider discovery is disabled.
-  #
-  #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
-
-  # the oauth2 token endpoint. Required if provider discovery is disabled.
-  #
-  #token_endpoint: "https://accounts.example.com/oauth2/token"
-
-  # the OIDC userinfo endpoint. Required if discovery is disabled and the
-  # "openid" scope is not requested.
-  #
-  #userinfo_endpoint: "https://accounts.example.com/userinfo"
-
-  # URI where to fetch the JWKS. Required if discovery is disabled and the
-  # "openid" scope is used.
-  #
-  #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
-
-  # Uncomment to skip metadata verification. Defaults to false.
-  #
-  # Use this if you are connecting to a provider that is not OpenID Connect
-  # compliant.
-  # Avoid this in production.
-  #
-  #skip_verification: true
-
-  # Whether to fetch the user profile from the userinfo endpoint. Valid
-  # values are: "auto" or "userinfo_endpoint".
-  #
-  # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
-  # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
-  #
-  #user_profile_method: "userinfo_endpoint"
-
-  # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
-  # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
-  #
-  #allow_existing_users: true
-
-  # An external module can be provided here as a custom solution to mapping
-  # attributes returned from a OIDC provider onto a matrix user.
-  #
-  user_mapping_provider:
-    # The custom module's class. Uncomment to use a custom module.
-    # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
-    #
-    # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
-    # for information on implementing a custom mapping provider.
-    #
-    #module: mapping_provider.OidcMappingProvider
-
-    # Custom configuration values for the module. This section will be passed as
-    # a Python dictionary to the user mapping provider module's `parse_config`
-    # method.
-    #
-    # The examples below are intended for the default provider: they should be
-    # changed if using a custom provider.
-    #
-    config:
-      # name of the claim containing a unique identifier for the user.
-      # Defaults to `sub`, which OpenID Connect compliant providers should provide.
-      #
-      #subject_claim: "sub"
-
-      # Jinja2 template for the localpart of the MXID.
-      #
-      # When rendering, this template is given the following variables:
-      #   * user: The claims returned by the UserInfo Endpoint and/or in the ID
-      #     Token
-      #
-      # This must be configured if using the default mapping provider.
-      #
-      localpart_template: "{{ user.preferred_username }}"
-
-      # Jinja2 template for the display name to set on first login.
-      #
-      # If unset, no displayname will be set.
-      #
-      #display_name_template: "{{ user.given_name }} {{ user.last_name }}"
-
-      # Jinja2 templates for extra attributes to send back to the client during
-      # login.
-      #
-      # Note that these are non-standard and clients will ignore them without modifications.
-      #
-      #extra_attributes:
-        #birthdate: "{{ user.birthdate }}"
-
+#   idp_id: a unique identifier for this identity provider. Used internally
+#       by Synapse; should be a single word such as 'github'.
+#
+#       Note that, if this is changed, users authenticating via that provider
+#       will no longer be recognised as the same user!
+#
+#   idp_name: A user-facing name for this identity provider, which is used to
+#       offer the user a choice of login mechanisms.
+#
+#   idp_icon: An optional icon for this identity provider, which is presented
+#       by clients and Synapse's own IdP picker page. If given, must be an
+#       MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
+#       obtain such an MXC URI is to upload an image to an (unencrypted) room
+#       and then copy the "url" from the source of the event.)
+#
+#   idp_brand: An optional brand for this identity provider, allowing clients
+#       to style the login flow according to the identity provider in question.
+#       See the spec for possible options here.
+#
+#   discover: set to 'false' to disable the use of the OIDC discovery mechanism
+#       to discover endpoints. Defaults to true.
+#
+#   issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+#       is enabled) to discover the provider's endpoints.
+#
+#   client_id: Required. oauth2 client id to use.
+#
+#   client_secret: Required. oauth2 client secret to use.
+#
+#   client_auth_method: auth method to use when exchanging the token. Valid
+#       values are 'client_secret_basic' (default), 'client_secret_post' and
+#       'none'.
+#
+#   scopes: list of scopes to request. This should normally include the "openid"
+#       scope. Defaults to ["openid"].
+#
+#   authorization_endpoint: the oauth2 authorization endpoint. Required if
+#       provider discovery is disabled.
+#
+#   token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+#       disabled.
+#
+#   userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+#       disabled and the 'openid' scope is not requested.
+#
+#   jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+#       the 'openid' scope is used.
+#
+#   skip_verification: set to 'true' to skip metadata verification. Use this if
+#       you are connecting to a provider that is not OpenID Connect compliant.
+#       Defaults to false. Avoid this in production.
+#
+#   user_profile_method: Whether to fetch the user profile from the userinfo
+#       endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+#
+#       Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+#       included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+#       userinfo endpoint.
+#
+#   allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+#       match a pre-existing account instead of failing. This could be used if
+#       switching from password logins to OIDC. Defaults to false.
+#
+#   user_mapping_provider: Configuration for how attributes returned from a OIDC
+#       provider are mapped onto a matrix user. This setting has the following
+#       sub-properties:
+#
+#       module: The class name of a custom mapping module. Default is
+#           'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'.
+#           See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+#           for information on implementing a custom mapping provider.
+#
+#       config: Configuration for the mapping provider module. This section will
+#           be passed as a Python dictionary to the user mapping provider
+#           module's `parse_config` method.
+#
+#           For the default provider, the following settings are available:
+#
+#             subject_claim: name of the claim containing a unique identifier
+#                 for the user. Defaults to 'sub', which OpenID Connect
+#                 compliant providers should provide.
+#
+#             localpart_template: Jinja2 template for the localpart of the MXID.
+#                 If this is not set, the user will be prompted to choose their
+#                 own username (see 'sso_auth_account_details.html' in the 'sso'
+#                 section of this file).
+#
+#             display_name_template: Jinja2 template for the display name to set
+#                 on first login. If unset, no displayname will be set.
+#
+#             email_template: Jinja2 template for the email address of the user.
+#                 If unset, no email address will be added to the account.
+#
+#             extra_attributes: a map of Jinja2 templates for extra attributes
+#                 to send back to the client during login.
+#                 Note that these are non-standard and clients will ignore them
+#                 without modifications.
+#
+#           When rendering, the Jinja2 templates are given a 'user' variable,
+#           which is set to the claims returned by the UserInfo Endpoint and/or
+#           in the ID Token.
+#
+# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
+# for information on how to configure these options.
+#
+# For backwards compatibility, it is also possible to configure a single OIDC
+# provider via an 'oidc_config' setting. This is now deprecated and admins are
+# advised to migrate to the 'oidc_providers' format. (When doing that migration,
+# use 'oidc' for the idp_id to ensure that existing users continue to be
+# recognised.)
+#
+oidc_providers:
+  # Generic example
+  #
+  #- idp_id: my_idp
+  #  idp_name: "My OpenID provider"
+  #  idp_icon: "mxc://example.com/mediaid"
+  #  discover: false
+  #  issuer: "https://accounts.example.com/"
+  #  client_id: "provided-by-your-issuer"
+  #  client_secret: "provided-by-your-issuer"
+  #  client_auth_method: client_secret_post
+  #  scopes: ["openid", "profile"]
+  #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+  #  token_endpoint: "https://accounts.example.com/oauth2/token"
+  #  userinfo_endpoint: "https://accounts.example.com/userinfo"
+  #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+  #  skip_verification: true
+  #  user_mapping_provider:
+  #    config:
+  #      subject_claim: "id"
+  #      localpart_template: "{ user.login }"
+  #      display_name_template: "{ user.name }"
+  #      email_template: "{ user.email }"
+
+  # For use with Keycloak
+  #
+  #- idp_id: keycloak
+  #  idp_name: Keycloak
+  #  issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+  #  client_id: "synapse"
+  #  client_secret: "copy secret generated in Keycloak UI"
+  #  scopes: ["openid", "profile"]
+
+  # For use with Github
+  #
+  #- idp_id: github
+  #  idp_name: Github
+  #  idp_brand: org.matrix.github
+  #  discover: false
+  #  issuer: "https://github.com/"
+  #  client_id: "your-client-id" # TO BE FILLED
+  #  client_secret: "your-client-secret" # TO BE FILLED
+  #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+  #  token_endpoint: "https://github.com/login/oauth/access_token"
+  #  userinfo_endpoint: "https://api.github.com/user"
+  #  scopes: ["read:user"]
+  #  user_mapping_provider:
+  #    config:
+  #      subject_claim: "id"
+  #      localpart_template: "{ user.login }"
+  #      display_name_template: "{ user.name }"
 
 
 # Enable Central Authentication Service (CAS) for registration and login.
@@ -1836,10 +1910,6 @@ cas_config:
   #
   #server_url: "https://cas-server.com"
 
-  # The public URL of the homeserver.
-  #
-  #service_url: "https://homeserver.domain.com:8448"
-
   # The attribute of the CAS response to use as the display name.
   #
   # If unset, no displayname will be set.
@@ -1882,41 +1952,135 @@ sso:
     #  - https://my.custom.client/
 
     # Directory in which Synapse will try to find the template files below.
-    # If not set, default templates from within the Synapse package will be used.
-    #
-    # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
-    # If you *do* uncomment it, you will need to make sure that all the templates
-    # below are in the directory.
+    # If not set, or the files named below are not found within the template
+    # directory, default templates from within the Synapse package will be used.
     #
     # Synapse will look for the following templates in this directory:
     #
+    # * HTML page to prompt the user to choose an Identity Provider during
+    #   login: 'sso_login_idp_picker.html'.
+    #
+    #   This is only used if multiple SSO Identity Providers are configured.
+    #
+    #   When rendering, this template is given the following variables:
+    #     * redirect_url: the URL that the user will be redirected to after
+    #       login.
+    #
+    #     * server_name: the homeserver's name.
+    #
+    #     * providers: a list of available Identity Providers. Each element is
+    #       an object with the following attributes:
+    #
+    #         * idp_id: unique identifier for the IdP
+    #         * idp_name: user-facing name for the IdP
+    #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+    #              for the IdP
+    #         * idp_brand: if specified in the IdP config, a textual identifier
+    #              for the brand of the IdP
+    #
+    #   The rendered HTML page should contain a form which submits its results
+    #   back as a GET request, with the following query parameters:
+    #
+    #     * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+    #       to the template)
+    #
+    #     * idp: the 'idp_id' of the chosen IDP.
+    #
+    # * HTML page to prompt new users to enter a userid and confirm other
+    #   details: 'sso_auth_account_details.html'. This is only shown if the
+    #   SSO implementation (with any user_mapping_provider) does not return
+    #   a localpart.
+    #
+    #   When rendering, this template is given the following variables:
+    #
+    #     * server_name: the homeserver's name.
+    #
+    #     * idp: details of the SSO Identity Provider that the user logged in
+    #       with: an object with the following attributes:
+    #
+    #         * idp_id: unique identifier for the IdP
+    #         * idp_name: user-facing name for the IdP
+    #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+    #              for the IdP
+    #         * idp_brand: if specified in the IdP config, a textual identifier
+    #              for the brand of the IdP
+    #
+    #     * user_attributes: an object containing details about the user that
+    #       we received from the IdP. May have the following attributes:
+    #
+    #         * display_name: the user's display_name
+    #         * emails: a list of email addresses
+    #
+    #   The template should render a form which submits the following fields:
+    #
+    #     * username: the localpart of the user's chosen user id
+    #
+    # * HTML page allowing the user to consent to the server's terms and
+    #   conditions. This is only shown for new users, and only if
+    #   `user_consent.require_at_registration` is set.
+    #
+    #   When rendering, this template is given the following variables:
+    #
+    #     * server_name: the homeserver's name.
+    #
+    #     * user_id: the user's matrix proposed ID.
+    #
+    #     * user_profile.display_name: the user's proposed display name, if any.
+    #
+    #     * consent_version: the version of the terms that the user will be
+    #       shown
+    #
+    #     * terms_url: a link to the page showing the terms.
+    #
+    #   The template should render a form which submits the following fields:
+    #
+    #     * accepted_version: the version of the terms accepted by the user
+    #       (ie, 'consent_version' from the input variables).
+    #
     # * HTML page for a confirmation step before redirecting back to the client
     #   with the login token: 'sso_redirect_confirm.html'.
     #
-    #   When rendering, this template is given three variables:
-    #     * redirect_url: the URL the user is about to be redirected to. Needs
-    #                     manual escaping (see
-    #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+    #   When rendering, this template is given the following variables:
+    #
+    #     * redirect_url: the URL the user is about to be redirected to.
     #
     #     * display_url: the same as `redirect_url`, but with the query
     #                    parameters stripped. The intention is to have a
     #                    human-readable URL to show to users, not to use it as
-    #                    the final address to redirect to. Needs manual escaping
-    #                    (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+    #                    the final address to redirect to.
     #
     #     * server_name: the homeserver's name.
     #
+    #     * new_user: a boolean indicating whether this is the user's first time
+    #          logging in.
+    #
+    #     * user_id: the user's matrix ID.
+    #
+    #     * user_profile.avatar_url: an MXC URI for the user's avatar, if any.
+    #           None if the user has not set an avatar.
+    #
+    #     * user_profile.display_name: the user's display name. None if the user
+    #           has not set a display name.
+    #
     # * HTML page which notifies the user that they are authenticating to confirm
     #   an operation on their account during the user interactive authentication
     #   process: 'sso_auth_confirm.html'.
     #
     #   When rendering, this template is given the following variables:
-    #     * redirect_url: the URL the user is about to be redirected to. Needs
-    #                     manual escaping (see
-    #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+    #     * redirect_url: the URL the user is about to be redirected to.
     #
     #     * description: the operation which the user is being asked to confirm
     #
+    #     * idp: details of the Identity Provider that we will use to confirm
+    #       the user's identity: an object with the following attributes:
+    #
+    #         * idp_id: unique identifier for the IdP
+    #         * idp_name: user-facing name for the IdP
+    #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+    #              for the IdP
+    #         * idp_brand: if specified in the IdP config, a textual identifier
+    #              for the brand of the IdP
+    #
     # * HTML page shown after a successful user interactive authentication session:
     #   'sso_auth_success.html'.
     #
@@ -1925,6 +2089,14 @@ sso:
     #
     #   This template has no additional variables.
     #
+    # * HTML page shown after a user-interactive authentication session which
+    #   does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+    #
+    #   When rendering, this template is given the following variables:
+    #     * server_name: the homeserver's name.
+    #     * user_id_to_verify: the MXID of the user that we are trying to
+    #       validate.
+    #
     # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
     #   attempts to login: 'sso_account_deactivated.html'.
     #
@@ -2050,6 +2222,21 @@ password_config:
       #
       #require_uppercase: true
 
+ui_auth:
+    # The number of milliseconds to allow a user-interactive authentication
+    # session to be active.
+    #
+    # This defaults to 0, meaning the user is queried for their credentials
+    # before every action, but this can be overridden to alow a single
+    # validation to be re-used.  This weakens the protections afforded by
+    # the user-interactive authentication process, by allowing for multiple
+    # (and potentially different) operations to use the same validation session.
+    #
+    # Uncomment below to allow for credential validation to last for 15
+    # seconds.
+    #
+    #session_timeout: 15000
+
 
 # Configuration for sending emails from Synapse.
 #
@@ -2115,10 +2302,15 @@ email:
   #
   #validation_token_lifetime: 15m
 
-  # Directory in which Synapse will try to find the template files below.
-  # If not set, default templates from within the Synapse package will be used.
+  # The web client location to direct users to during an invite. This is passed
+  # to the identity server as the org.matrix.web_client_location key. Defaults
+  # to unset, giving no guidance to the identity server.
   #
-  # Do not uncomment this setting unless you want to customise the templates.
+  #invite_client_location: https://app.element.io
+
+  # Directory in which Synapse will try to find the template files below.
+  # If not set, or the files named below are not found within the template
+  # directory, default templates from within the Synapse package will be used.
   #
   # Synapse will look for the following templates in this directory:
   #
@@ -2327,7 +2519,7 @@ spam_checker:
 # If enabled, non server admins can only create groups with local parts
 # starting with this prefix
 #
-#group_creation_prefix: "unofficial/"
+#group_creation_prefix: "unofficial_"
 
 
 
@@ -2592,6 +2784,13 @@ opentracing:
 #
 #run_background_tasks_on: worker1
 
+# A shared secret used by the replication APIs to authenticate HTTP requests
+# from workers.
+#
+# By default this is unused and traffic is not authenticated.
+#
+#worker_replication_secret: ""
+
 
 # Configuration for Redis when using workers. This *must* be enabled when
 # using workers (unless using old style direct TCP configuration).
diff --git a/docs/spam_checker.md b/docs/spam_checker.md
index 7fc08f1b70..5b4f6428e6 100644
--- a/docs/spam_checker.md
+++ b/docs/spam_checker.md
@@ -22,6 +22,8 @@ well as some specific methods:
 * `user_may_create_room`
 * `user_may_create_room_alias`
 * `user_may_publish_room`
+* `check_username_for_spam`
+* `check_registration_for_spam`
 
 The details of the each of these methods (as well as their inputs and outputs)
 are documented in the `synapse.events.spamcheck.SpamChecker` class.
@@ -32,28 +34,33 @@ call back into the homeserver internals.
 ### Example
 
 ```python
+from synapse.spam_checker_api import RegistrationBehaviour
+
 class ExampleSpamChecker:
     def __init__(self, config, api):
         self.config = config
         self.api = api
 
-    def check_event_for_spam(self, foo):
+    async def check_event_for_spam(self, foo):
         return False  # allow all events
 
-    def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+    async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
         return True  # allow all invites
 
-    def user_may_create_room(self, userid):
+    async def user_may_create_room(self, userid):
         return True  # allow all room creations
 
-    def user_may_create_room_alias(self, userid, room_alias):
+    async def user_may_create_room_alias(self, userid, room_alias):
         return True  # allow all room aliases
 
-    def user_may_publish_room(self, userid, room_id):
+    async def user_may_publish_room(self, userid, room_id):
         return True  # allow publishing of all rooms
 
-    def check_username_for_spam(self, user_profile):
+    async def check_username_for_spam(self, user_profile):
         return False  # allow all usernames
+
+    async def check_registration_for_spam(self, email_threepid, username, request_info):
+        return RegistrationBehaviour.ALLOW  # allow all registrations
 ```
 
 ## Configuration
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index ab2a648910..e1d6ede7ba 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -15,12 +15,18 @@ where SAML mapping providers come into play.
 SSO mapping providers are currently supported for OpenID and SAML SSO
 configurations. Please see the details below for how to implement your own.
 
-It is the responsibility of the mapping provider to normalise the SSO attributes
-and map them to a valid Matrix ID. The
-[specification for Matrix IDs](https://matrix.org/docs/spec/appendices#user-identifiers)
-has some information about what is considered valid. Alternately an easy way to
-ensure it is valid is to use a Synapse utility function:
-`synapse.types.map_username_to_mxid_localpart`.
+It is up to the mapping provider whether the user should be assigned a predefined
+Matrix ID based on the SSO attributes, or if the user should be allowed to
+choose their own username.
+
+In the first case - where users are automatically allocated a Matrix ID - it is
+the responsibility of the mapping provider to normalise the SSO attributes and
+map them to a valid Matrix ID. The [specification for Matrix
+IDs](https://matrix.org/docs/spec/appendices#user-identifiers) has some
+information about what is considered valid.
+
+If the mapping provider does not assign a Matrix ID, then Synapse will
+automatically serve an HTML page allowing the user to pick their own username.
 
 External mapping providers are provided to Synapse in the form of an external
 Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere,
@@ -80,8 +86,9 @@ A custom mapping provider must specify the following methods:
                      with failures=1. The method should then return a different
                      `localpart` value, such as `john.doe1`.
     - Returns a dictionary with two keys:
-      - localpart: A required string, used to generate the Matrix ID.
-      - displayname: An optional string, the display name for the user.
+      - `localpart`: A string, used to generate the Matrix ID. If this is
+        `None`, the user is prompted to pick their own username.
+      - `displayname`: An optional string, the display name for the user.
 * `get_extra_attributes(self, userinfo, token)`
     - This method must be async.
     - Arguments:
@@ -116,11 +123,13 @@ comment these options out and use those specified by the module instead.
 
 A custom mapping provider must specify the following methods:
 
-* `__init__(self, parsed_config)`
+* `__init__(self, parsed_config, module_api)`
    - Arguments:
      - `parsed_config` - A configuration object that is the return value of the
        `parse_config` method. You should set any configuration options needed by
        the module here.
+     - `module_api` - a `synapse.module_api.ModuleApi` object which provides the
+       stable API available for extension modules.
 * `parse_config(config)`
     - This method should have the `@staticmethod` decoration.
     - Arguments:
@@ -163,12 +172,13 @@ A custom mapping provider must specify the following methods:
                                 redirected to.
     - This method must return a dictionary, which will then be used by Synapse
       to build a new user. The following keys are allowed:
-       * `mxid_localpart` - Required. The mxid localpart of the new user.
+       * `mxid_localpart` - The mxid localpart of the new user.  If this is
+         `None`, the user is prompted to pick their own username.
        * `displayname` - The displayname of the new user. If not provided, will default to
                          the value of `mxid_localpart`.
        * `emails` - A list of emails for the new user. If not provided, will
                     default to an empty list.
-       
+
        Alternatively it can raise a `synapse.api.errors.RedirectException` to
        redirect the user to another page. This is useful to prompt the user for
        additional information, e.g. if you want them to provide their own username.
diff --git a/docs/systemd-with-workers/README.md b/docs/systemd-with-workers/README.md
index 8e57d4f62e..cfa36be7b4 100644
--- a/docs/systemd-with-workers/README.md
+++ b/docs/systemd-with-workers/README.md
@@ -31,7 +31,7 @@ There is no need for a separate configuration file for the master process.
 1. Adjust synapse configuration files as above.
 1. Copy the `*.service` and `*.target` files in [system](system) to
 `/etc/systemd/system`.
-1. Run `systemctl deamon-reload` to tell systemd to load the new unit files.
+1. Run `systemctl daemon-reload` to tell systemd to load the new unit files.
 1. Run `systemctl enable matrix-synapse.service`. This will configure the
 synapse master process to be started as part of the `matrix-synapse.target`
 target.
diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index a470c274a5..e8f13ad484 100644
--- a/docs/turn-howto.md
+++ b/docs/turn-howto.md
@@ -232,6 +232,12 @@ Here are a few things to try:
 
    (Understanding the output is beyond the scope of this document!)
 
+ * You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/.
+   Note that this test is not fully reliable yet, so don't be discouraged if
+   the test fails.
+   [Here](https://github.com/matrix-org/voip-tester) is the github repo of the
+   source of the tester, where you can file bug reports.
+
  * There is a WebRTC test tool at
    https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
    use it, you will need a username/password for your TURN server. You can
diff --git a/docs/workers.md b/docs/workers.md
index c53d1bd2ff..f7fc6df119 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
 be used for demo purposes and any admin considering workers should already be
 running PostgreSQL.
 
+See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
+for a higher level overview.
+
 ## Main process/worker communication
 
 The processes communicate with each other via a Synapse-specific protocol called
@@ -37,6 +40,9 @@ which relays replication commands between processes. This can give a significant
 cpu saving on the main process and will be a prerequisite for upcoming
 performance improvements.
 
+If Redis support is enabled Synapse will use it as a shared cache, as well as a
+pub/sub mechanism.
+
 See the [Architectural diagram](#architectural-diagram) section at the end for
 a visualisation of what this looks like.
 
@@ -56,7 +62,7 @@ The appropriate dependencies must also be installed for Synapse. If using a
 virtualenv, these can be installed with:
 
 ```sh
-pip install matrix-synapse[redis]
+pip install "matrix-synapse[redis]"
 ```
 
 Note that these dependencies are included when synapse is installed with `pip
@@ -89,7 +95,8 @@ shared configuration file.
 Normally, only a couple of changes are needed to make an existing configuration
 file suitable for use with workers. First, you need to enable an "HTTP replication
 listener" for the main process; and secondly, you need to enable redis-based
-replication. For example:
+replication. Optionally, a shared secret can be used to authenticate HTTP
+traffic between workers. For example:
 
 
 ```yaml
@@ -103,6 +110,9 @@ listeners:
     resources:
      - names: [replication]
 
+# Add a random shared secret to authenticate traffic.
+worker_replication_secret: ""
+
 redis:
     enabled: true
 ```
@@ -210,6 +220,7 @@ expressions:
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
     ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
+    ^/_matrix/client/(api/v1|r0|unstable)/devices$
     ^/_matrix/client/(api/v1|r0|unstable)/keys/query$
     ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
     ^/_matrix/client/versions$
@@ -217,7 +228,6 @@ expressions:
     ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
     ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
     ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
-    ^/_synapse/client/password_reset/email/submit_token$
 
     # Registration/login requests
     ^/_matrix/client/(api/v1|r0|unstable)/login$
@@ -225,6 +235,7 @@ expressions:
     ^/_matrix/client/(r0|unstable)/auth/.*/fallback/web$
 
     # Event sending requests
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
@@ -247,25 +258,29 @@ Additionally, the following endpoints should be included if Synapse is configure
 to use SSO (you only need to include the ones for whichever SSO provider you're
 using):
 
+    # for all SSO providers
+    ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect
+    ^/_synapse/client/pick_idp$
+    ^/_synapse/client/pick_username
+    ^/_synapse/client/new_user_consent$
+    ^/_synapse/client/sso_register$
+
     # OpenID Connect requests.
-    ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
-    ^/_synapse/oidc/callback$
+    ^/_synapse/client/oidc/callback$
 
     # SAML requests.
-    ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
-    ^/_matrix/saml2/authn_response$
+    ^/_synapse/client/saml2/authn_response$
 
     # CAS requests.
-    ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$
     ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$
 
-Note that a HTTP listener with `client` and `federation` resources must be
-configured in the `worker_listeners` option in the worker config.
-
-Ensure that all SSO logins go to a single process (usually the main process). 
+Ensure that all SSO logins go to a single process.
 For multiple workers not handling the SSO endpoints properly, see
 [#7530](https://github.com/matrix-org/synapse/issues/7530).
 
+Note that a HTTP listener with `client` and `federation` resources must be
+configured in the `worker_listeners` option in the worker config.
+
 #### Load balancing
 
 It is possible to run multiple instances of this worker app, with incoming requests
diff --git a/mypy.ini b/mypy.ini
index 3c8d303064..68a4533973 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -7,44 +7,23 @@ show_error_codes = True
 show_traceback = True
 mypy_path = stubs
 warn_unreachable = True
+
+# To find all folders that pass mypy you run:
+#
+#   find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null"  \; -print
+
 files =
   scripts-dev/sign_json,
   synapse/api,
   synapse/appservice,
   synapse/config,
+  synapse/crypto,
   synapse/event_auth.py,
   synapse/events/builder.py,
   synapse/events/validator.py,
   synapse/events/spamcheck.py,
   synapse/federation,
-  synapse/handlers/_base.py,
-  synapse/handlers/account_data.py,
-  synapse/handlers/account_validity.py,
-  synapse/handlers/appservice.py,
-  synapse/handlers/auth.py,
-  synapse/handlers/cas_handler.py,
-  synapse/handlers/deactivate_account.py,
-  synapse/handlers/device.py,
-  synapse/handlers/devicemessage.py,
-  synapse/handlers/directory.py,
-  synapse/handlers/events.py,
-  synapse/handlers/federation.py,
-  synapse/handlers/identity.py,
-  synapse/handlers/initial_sync.py,
-  synapse/handlers/message.py,
-  synapse/handlers/oidc_handler.py,
-  synapse/handlers/pagination.py,
-  synapse/handlers/password_policy.py,
-  synapse/handlers/presence.py,
-  synapse/handlers/profile.py,
-  synapse/handlers/read_marker.py,
-  synapse/handlers/register.py,
-  synapse/handlers/room.py,
-  synapse/handlers/room_member.py,
-  synapse/handlers/room_member_worker.py,
-  synapse/handlers/saml_handler.py,
-  synapse/handlers/sync.py,
-  synapse/handlers/ui_auth,
+  synapse/handlers,
   synapse/http/client.py,
   synapse/http/federation/matrix_federation_agent.py,
   synapse/http/federation/well_known_resolver.py,
@@ -55,32 +34,45 @@ files =
   synapse/metrics,
   synapse/module_api,
   synapse/notifier.py,
-  synapse/push/pusherpool.py,
-  synapse/push/push_rule_evaluator.py,
+  synapse/push,
   synapse/replication,
   synapse/rest,
   synapse/server.py,
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/__init__.py,
+  synapse/storage/_base.py,
+  synapse/storage/background_updates.py,
   synapse/storage/databases/main/appservice.py,
   synapse/storage/databases/main/events.py,
+  synapse/storage/databases/main/keys.py,
+  synapse/storage/databases/main/pusher.py,
   synapse/storage/databases/main/registration.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
   synapse/storage/engines,
+  synapse/storage/keys.py,
   synapse/storage/persist_events.py,
+  synapse/storage/prepare_database.py,
+  synapse/storage/purge_events.py,
+  synapse/storage/push_rule.py,
+  synapse/storage/relations.py,
+  synapse/storage/roommember.py,
   synapse/storage/state.py,
+  synapse/storage/types.py,
   synapse/storage/util,
   synapse/streams,
   synapse/types.py,
   synapse/util/async_helpers.py,
   synapse/util/caches,
   synapse/util/metrics.py,
+  synapse/util/stringutils.py,
   tests/replication,
   tests/test_utils,
   tests/handlers/test_password_providers.py,
+  tests/rest/client/v1/test_login.py,
   tests/rest/client/v2_alpha/test_auth.py,
   tests/util/test_stream_change_cache.py
 
@@ -108,6 +100,9 @@ ignore_missing_imports = True
 [mypy-h11]
 ignore_missing_imports = True
 
+[mypy-msgpack]
+ignore_missing_imports = True
+
 [mypy-opentracing]
 ignore_missing_imports = True
 
@@ -167,3 +162,9 @@ ignore_missing_imports = True
 
 [mypy-hiredis]
 ignore_missing_imports = True
+
+[mypy-josepy.*]
+ignore_missing_imports = True
+
+[mypy-txacme.*]
+ignore_missing_imports = True
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index f328ab57d5..fe2965cd36 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -80,7 +80,8 @@ else
   # then lint everything!
   if [[ -z ${files+x} ]]; then
     # Lint all source code files and directories
-    files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
+    # Note: this list aims the mirror the one in tox.ini
+    files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
   fi
 fi
 
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index 5882f3a0b0..f7f18805e4 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -31,6 +31,8 @@ class SynapsePlugin(Plugin):
     ) -> Optional[Callable[[MethodSigContext], CallableType]]:
         if fullname.startswith(
             "synapse.util.caches.descriptors._CachedFunction.__call__"
+        ) or fullname.startswith(
+            "synapse.util.caches.descriptors._LruCachedFunction.__call__"
         ):
             return cached_function_method_signature
         return None
diff --git a/scripts/generate_log_config b/scripts/generate_log_config
index b6957f48a3..a13a5634a3 100755
--- a/scripts/generate_log_config
+++ b/scripts/generate_log_config
@@ -40,4 +40,6 @@ if __name__ == "__main__":
     )
 
     args = parser.parse_args()
-    args.output_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file))
+    out = args.output_file
+    out.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file))
+    out.flush()
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 5ad17aa90f..69bf9110a6 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db")
 
 BOOLEAN_COLUMNS = {
     "events": ["processed", "outlier", "contains_url"],
-    "rooms": ["is_public"],
+    "rooms": ["is_public", "has_auth_chain_index"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
     "presence_stream": ["currently_active"],
@@ -629,6 +629,7 @@ class Porter(object):
             await self._setup_state_group_id_seq()
             await self._setup_user_id_seq()
             await self._setup_events_stream_seqs()
+            await self._setup_device_inbox_seq()
 
             # Step 3. Get tables.
             self.progress.set_state("Fetching tables")
@@ -911,6 +912,32 @@ class Porter(object):
             "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
         )
 
+    async def _setup_device_inbox_seq(self):
+        """Set the device inbox sequence to the correct value.
+        """
+        curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_inbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="device_federation_outbox",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), 1)",
+            allow_none=True,
+        )
+
+        next_id = max(curr_local_id, curr_federation_id) + 1
+
+        def r(txn):
+            txn.execute(
+                "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,)
+            )
+
+        return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r)
+
 
 ##############################################
 # The following is simply UI stuff
diff --git a/setup.py b/setup.py
index 9730afb41b..99425d52de 100755
--- a/setup.py
+++ b/setup.py
@@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
 #
 # We pin black so that our tests don't start failing on new releases.
 CONDITIONAL_REQUIREMENTS["lint"] = [
-    "isort==5.0.3",
+    "isort==5.7.0",
     "black==19.10b0",
     "flake8-comprehensions",
     "flake8",
@@ -121,6 +121,7 @@ setup(
     include_package_data=True,
     zip_safe=False,
     long_description=long_description,
+    long_description_content_type="text/x-rst",
     python_requires="~=3.5",
     classifiers=[
         "Development Status :: 5 - Production/Stable",
diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi
index 3f3af59f26..0368ba4703 100644
--- a/stubs/frozendict.pyi
+++ b/stubs/frozendict.pyi
@@ -15,16 +15,7 @@
 
 # Stub for frozendict.
 
-from typing import (
-    Any,
-    Hashable,
-    Iterable,
-    Iterator,
-    Mapping,
-    overload,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload
 
 _KT = TypeVar("_KT", bound=Hashable)  # Key type.
 _VT = TypeVar("_VT")  # Value type.
diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi
index 68779f968e..7b9fd079d9 100644
--- a/stubs/sortedcontainers/sorteddict.pyi
+++ b/stubs/sortedcontainers/sorteddict.pyi
@@ -7,17 +7,17 @@ from typing import (
     Callable,
     Dict,
     Hashable,
-    Iterator,
-    Iterable,
     ItemsView,
+    Iterable,
+    Iterator,
     KeysView,
     List,
     Mapping,
     Optional,
     Sequence,
+    Tuple,
     Type,
     TypeVar,
-    Tuple,
     Union,
     ValuesView,
     overload,
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 522244bb57..618548a305 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -15,13 +15,23 @@
 
 """Contains *incomplete* type hints for txredisapi.
 """
-
-from typing import List, Optional, Union, Type
+from typing import Any, List, Optional, Type, Union
 
 class RedisProtocol:
     def publish(self, channel: str, message: bytes): ...
+    async def ping(self) -> None: ...
+    async def set(
+        self,
+        key: str,
+        value: Any,
+        expire: Optional[int] = None,
+        pexpire: Optional[int] = None,
+        only_if_not_exists: bool = False,
+        only_if_exists: bool = False,
+    ) -> None: ...
+    async def get(self, key: str) -> Any: ...
 
-class SubscriberProtocol:
+class SubscriberProtocol(RedisProtocol):
     def __init__(self, *args, **kwargs): ...
     password: Optional[str]
     def subscribe(self, channels: Union[str, List[str]]): ...
@@ -40,14 +50,13 @@ def lazyConnection(
     convertNumbers: bool = ...,
 ) -> RedisProtocol: ...
 
-class SubscriberFactory:
-    def buildProtocol(self, addr): ...
-
 class ConnectionHandler: ...
 
 class RedisFactory:
     continueTrying: bool
     handler: RedisProtocol
+    pool: List[RedisProtocol]
+    replyTimeout: Optional[int]
     def __init__(
         self,
         uuid: str,
@@ -60,3 +69,7 @@ class RedisFactory:
         replyTimeout: Optional[int] = None,
         convertNumbers: Optional[int] = True,
     ): ...
+    def buildProtocol(self, addr) -> RedisProtocol: ...
+
+class SubscriberFactory(RedisFactory):
+    def __init__(self): ...
diff --git a/synapse/__init__.py b/synapse/__init__.py
index f2d3ac68eb..359276427f 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.24.0"
+__version__ = "1.27.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index bfcaf68b2a..67ecbd32ff 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,7 @@ from twisted.web.server import Request
 import synapse.types
 from synapse import event_auth
 from synapse.api.auth_blocking import AuthBlocking
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -31,7 +31,10 @@ from synapse.api.errors import (
     MissingClientTokenError,
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.appservice import ApplicationService
 from synapse.events import EventBase
+from synapse.http import get_request_user_agent
+from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import StateMap, UserID
@@ -184,8 +187,8 @@ class Auth:
             AuthError if access is denied for the user in the access token
         """
         try:
-            ip_addr = self.hs.get_ip_from_request(request)
-            user_agent = request.get_user_agent("")
+            ip_addr = request.getClientIP()
+            user_agent = get_request_user_agent(request)
 
             access_token = self.get_access_token_from_request(request)
 
@@ -273,7 +276,7 @@ class Auth:
             return None, None
 
         if app_service.ip_range_whitelist:
-            ip_address = IPAddress(self.hs.get_ip_from_request(request))
+            ip_address = IPAddress(request.getClientIP())
             if ip_address not in app_service.ip_range_whitelist:
                 return None, None
 
@@ -474,7 +477,7 @@ class Auth:
         now = self.hs.get_clock().time_msec()
         return now < expiry
 
-    def get_appservice_by_req(self, request):
+    def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
         token = self.get_access_token_from_request(request)
         service = self.store.get_app_service_by_token(token)
         if not service:
@@ -646,7 +649,8 @@ class Auth:
             )
             if (
                 visibility
-                and visibility.content["history_visibility"] == "world_readable"
+                and visibility.content.get("history_visibility")
+                == HistoryVisibility.WORLD_READABLE
             ):
                 return Membership.JOIN, None
             raise AuthError(
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 9c227218e0..d8088f524a 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -36,6 +36,7 @@ class AuthBlocking:
         self._limit_usage_by_mau = hs.config.limit_usage_by_mau
         self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
         self._server_name = hs.hostname
+        self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
 
     async def check_auth_blocking(
         self,
@@ -76,6 +77,12 @@ class AuthBlocking:
                 # We never block the server from doing actions on behalf of
                 # users.
                 return
+            elif requester.app_service and not self._track_appservice_user_ips:
+                # If we're authenticated as an appservice then we only block
+                # auth if `track_appservice_user_ips` is set, as that option
+                # implicitly means that application services are part of MAU
+                # limits.
+                return
 
         # Never fail an auth check for the server notices users or support user
         # This can be a problem where event creation is prohibited due to blocking
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 592abd844b..565a8cd76a 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -95,6 +95,8 @@ class EventTypes:
 
     Presence = "m.presence"
 
+    Dummy = "org.matrix.dummy_event"
+
 
 class RejectedReason:
     AUTH_ERROR = "auth_error"
@@ -160,3 +162,10 @@ class RoomEncryptionAlgorithms:
 class AccountDataTypes:
     DIRECT = "m.direct"
     IGNORED_USER_LIST = "m.ignored_user_list"
+
+
+class HistoryVisibility:
+    INVITED = "invited"
+    JOINED = "joined"
+    SHARED = "shared"
+    WORLD_READABLE = "world_readable"
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index f3ecbf36b6..de2cc15d33 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -51,11 +51,11 @@ class RoomDisposition:
 class RoomVersion:
     """An object which describes the unique attributes of a room version."""
 
-    identifier = attr.ib()  # str; the identifier for this version
-    disposition = attr.ib()  # str; one of the RoomDispositions
-    event_format = attr.ib()  # int; one of the EventFormatVersions
-    state_res = attr.ib()  # int; one of the StateResolutionVersions
-    enforce_key_validity = attr.ib()  # bool
+    identifier = attr.ib(type=str)  # the identifier for this version
+    disposition = attr.ib(type=str)  # one of the RoomDispositions
+    event_format = attr.ib(type=int)  # one of the EventFormatVersions
+    state_res = attr.ib(type=int)  # one of the StateResolutionVersions
+    enforce_key_validity = attr.ib(type=bool)
 
     # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
     special_case_aliases_auth = attr.ib(type=bool)
@@ -64,9 +64,11 @@ class RoomVersion:
     # * Floats
     # * NaN, Infinity, -Infinity
     strict_canonicaljson = attr.ib(type=bool)
-    # bool: MSC2209: Check 'notifications' key while verifying
+    # MSC2209: Check 'notifications' key while verifying
     # m.room.power_levels auth rules.
     limit_notifications_power_levels = attr.ib(type=bool)
+    # MSC2174/MSC2176: Apply updated redaction rules algorithm.
+    msc2176_redaction_rules = attr.ib(type=bool)
 
 
 class RoomVersions:
@@ -79,6 +81,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V2 = RoomVersion(
         "2",
@@ -89,6 +92,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V3 = RoomVersion(
         "3",
@@ -99,6 +103,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V4 = RoomVersion(
         "4",
@@ -109,6 +114,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V5 = RoomVersion(
         "5",
@@ -119,6 +125,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V6 = RoomVersion(
         "6",
@@ -129,6 +136,18 @@ class RoomVersions:
         special_case_aliases_auth=False,
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+    )
+    MSC2176 = RoomVersion(
+        "org.matrix.msc2176",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=True,
     )
 
 
@@ -141,5 +160,6 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V4,
         RoomVersions.V5,
         RoomVersions.V6,
+        RoomVersions.MSC2176,
     )
 }  # type: Dict[str, RoomVersion]
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 895b38ae76..9840a9d55b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2017 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,11 +16,12 @@
 import gc
 import logging
 import os
+import platform
 import signal
 import socket
 import sys
 import traceback
-from typing import Iterable
+from typing import Awaitable, Callable, Iterable
 
 from typing_extensions import NoReturn
 
@@ -143,6 +145,45 @@ def quit_with_error(error_string: str) -> NoReturn:
     sys.exit(1)
 
 
+def register_start(cb: Callable[..., Awaitable], *args, **kwargs) -> None:
+    """Register a callback with the reactor, to be called once it is running
+
+    This can be used to initialise parts of the system which require an asynchronous
+    setup.
+
+    Any exception raised by the callback will be printed and logged, and the process
+    will exit.
+    """
+
+    async def wrapper():
+        try:
+            await cb(*args, **kwargs)
+        except Exception:
+            # previously, we used Failure().printTraceback() here, in the hope that
+            # would give better tracebacks than traceback.print_exc(). However, that
+            # doesn't handle chained exceptions (with a __cause__ or __context__) well,
+            # and I *think* the need for Failure() is reduced now that we mostly use
+            # async/await.
+
+            # Write the exception to both the logs *and* the unredirected stderr,
+            # because people tend to get confused if it only goes to one or the other.
+            #
+            # One problem with this is that if people are using a logging config that
+            # logs to the console (as is common eg under docker), they will get two
+            # copies of the exception. We could maybe try to detect that, but it's
+            # probably a cost we can bear.
+            logger.fatal("Error during startup", exc_info=True)
+            print("Error during startup:", file=sys.__stderr__)
+            traceback.print_exc(file=sys.__stderr__)
+
+            # it's no use calling sys.exit here, since that just raises a SystemExit
+            # exception which is then caught by the reactor, and everything carries
+            # on as normal.
+            os._exit(1)
+
+    reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
+
+
 def listen_metrics(bind_addresses, port):
     """
     Start Prometheus metrics server.
@@ -227,7 +268,7 @@ def refresh_certificate(hs):
         logger.info("Context factories updated.")
 
 
-def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
+async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
     """
     Start a Synapse server or worker.
 
@@ -241,71 +282,67 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
         hs: homeserver instance
         listeners: Listener configuration ('listeners' in homeserver.yaml)
     """
-    try:
-        # Set up the SIGHUP machinery.
-        if hasattr(signal, "SIGHUP"):
+    # Set up the SIGHUP machinery.
+    if hasattr(signal, "SIGHUP"):
+        reactor = hs.get_reactor()
 
-            @wrap_as_background_process("sighup")
-            def handle_sighup(*args, **kwargs):
-                # Tell systemd our state, if we're using it. This will silently fail if
-                # we're not using systemd.
-                sdnotify(b"RELOADING=1")
+        @wrap_as_background_process("sighup")
+        def handle_sighup(*args, **kwargs):
+            # Tell systemd our state, if we're using it. This will silently fail if
+            # we're not using systemd.
+            sdnotify(b"RELOADING=1")
 
-                for i, args, kwargs in _sighup_callbacks:
-                    i(*args, **kwargs)
+            for i, args, kwargs in _sighup_callbacks:
+                i(*args, **kwargs)
 
-                sdnotify(b"READY=1")
+            sdnotify(b"READY=1")
 
-            # We defer running the sighup handlers until next reactor tick. This
-            # is so that we're in a sane state, e.g. flushing the logs may fail
-            # if the sighup happens in the middle of writing a log entry.
-            def run_sighup(*args, **kwargs):
-                hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
+        # We defer running the sighup handlers until next reactor tick. This
+        # is so that we're in a sane state, e.g. flushing the logs may fail
+        # if the sighup happens in the middle of writing a log entry.
+        def run_sighup(*args, **kwargs):
+            # `callFromThread` should be "signal safe" as well as thread
+            # safe.
+            reactor.callFromThread(handle_sighup, *args, **kwargs)
 
-            signal.signal(signal.SIGHUP, run_sighup)
+        signal.signal(signal.SIGHUP, run_sighup)
 
-            register_sighup(refresh_certificate, hs)
+        register_sighup(refresh_certificate, hs)
 
-        # Load the certificate from disk.
-        refresh_certificate(hs)
+    # Load the certificate from disk.
+    refresh_certificate(hs)
 
-        # Start the tracer
-        synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
-            hs
-        )
+    # Start the tracer
+    synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
+        hs
+    )
 
-        # It is now safe to start your Synapse.
-        hs.start_listening(listeners)
-        hs.get_datastore().db_pool.start_profiling()
-        hs.get_pusherpool().start()
+    # It is now safe to start your Synapse.
+    hs.start_listening(listeners)
+    hs.get_datastore().db_pool.start_profiling()
+    hs.get_pusherpool().start()
 
-        # Log when we start the shut down process.
-        hs.get_reactor().addSystemEventTrigger(
-            "before", "shutdown", logger.info, "Shutting down..."
-        )
+    # Log when we start the shut down process.
+    hs.get_reactor().addSystemEventTrigger(
+        "before", "shutdown", logger.info, "Shutting down..."
+    )
 
-        setup_sentry(hs)
-        setup_sdnotify(hs)
-
-        # If background tasks are running on the main process, start collecting the
-        # phone home stats.
-        if hs.config.run_background_tasks:
-            start_phone_stats_home(hs)
-
-        # We now freeze all allocated objects in the hopes that (almost)
-        # everything currently allocated are things that will be used for the
-        # rest of time. Doing so means less work each GC (hopefully).
-        #
-        # This only works on Python 3.7
-        if sys.version_info >= (3, 7):
-            gc.collect()
-            gc.freeze()
-    except Exception:
-        traceback.print_exc(file=sys.stderr)
-        reactor = hs.get_reactor()
-        if reactor.running:
-            reactor.stop()
-        sys.exit(1)
+    setup_sentry(hs)
+    setup_sdnotify(hs)
+
+    # If background tasks are running on the main process, start collecting the
+    # phone home stats.
+    if hs.config.run_background_tasks:
+        start_phone_stats_home(hs)
+
+    # We now freeze all allocated objects in the hopes that (almost)
+    # everything currently allocated are things that will be used for the
+    # rest of time. Doing so means less work each GC (hopefully).
+    #
+    # This only works on Python 3.7
+    if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
+        gc.collect()
+        gc.freeze()
 
 
 def setup_sentry(hs):
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1b511890aa..516f2464b4 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -21,7 +21,8 @@ from typing import Dict, Iterable, Optional, Set
 
 from typing_extensions import ContextManager
 
-from twisted.internet import address, reactor
+from twisted.internet import address
+from twisted.web.resource import IResource
 
 import synapse
 import synapse.events
@@ -34,6 +35,7 @@ from synapse.api.urls import (
     SERVER_KEY_V2_PREFIX,
 )
 from synapse.app import _base
+from synapse.app._base import register_start
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
@@ -89,45 +91,47 @@ from synapse.replication.tcp.streams import (
     ToDeviceStream,
 )
 from synapse.rest.admin import register_servlets_for_media_repo
-from synapse.rest.client.v1 import events
+from synapse.rest.client.v1 import events, login, room
 from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
-from synapse.rest.client.v1.login import LoginRestServlet
 from synapse.rest.client.v1.profile import (
     ProfileAvatarURLRestServlet,
     ProfileDisplaynameRestServlet,
     ProfileRestServlet,
 )
 from synapse.rest.client.v1.push_rule import PushRuleRestServlet
-from synapse.rest.client.v1.room import (
-    JoinedRoomMemberListRestServlet,
-    JoinRoomAliasServlet,
-    PublicRoomListRestServlet,
-    RoomEventContextServlet,
-    RoomInitialSyncRestServlet,
-    RoomMemberListRestServlet,
-    RoomMembershipRestServlet,
-    RoomMessageListRestServlet,
-    RoomSendEventRestServlet,
-    RoomStateEventRestServlet,
-    RoomStateRestServlet,
-    RoomTypingRestServlet,
-)
 from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, sync, user_directory
+from synapse.rest.client.v2_alpha import (
+    account_data,
+    groups,
+    read_marker,
+    receipts,
+    room_keys,
+    sync,
+    tags,
+    user_directory,
+)
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
 from synapse.rest.client.v2_alpha.account_data import (
     AccountDataServlet,
     RoomAccountDataServlet,
 )
-from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
+from synapse.rest.client.v2_alpha.devices import DevicesRestServlet
+from synapse.rest.client.v2_alpha.keys import (
+    KeyChangesServlet,
+    KeyQueryServlet,
+    OneTimeKeyServlet,
+)
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.server import HomeServer, cache_in_self
 from synapse.storage.databases.main.censor_events import CensorEventsStore
 from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
+from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore
 from synapse.storage.databases.main.media_repository import MediaRepositoryStore
 from synapse.storage.databases.main.metrics import ServerMetricsStore
 from synapse.storage.databases.main.monthly_active_users import (
@@ -266,7 +270,6 @@ class GenericWorkerPresence(BasePresenceHandler):
         super().__init__(hs)
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
-        self.http_client = hs.get_simple_http_client()
 
         self._presence_enabled = hs.config.use_presence
 
@@ -460,6 +463,7 @@ class GenericWorkerSlavedStore(
     UserDirectoryStore,
     StatsStore,
     UIAuthWorkerStore,
+    EndToEndRoomKeyStore,
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedReceiptsStore,
@@ -504,7 +508,7 @@ class GenericWorkerServer(HomeServer):
             site_tag = port
 
         # We always include a health resource.
-        resources = {"/health": HealthResource()}
+        resources = {"/health": HealthResource()}  # type: Dict[str, IResource]
 
         for res in listener_config.http_options.resources:
             for name in res.names:
@@ -513,36 +517,36 @@ class GenericWorkerServer(HomeServer):
                 elif name == "client":
                     resource = JsonResource(self, canonical_json=False)
 
-                    PublicRoomListRestServlet(self).register(resource)
-                    RoomMemberListRestServlet(self).register(resource)
-                    JoinedRoomMemberListRestServlet(self).register(resource)
-                    RoomStateRestServlet(self).register(resource)
-                    RoomEventContextServlet(self).register(resource)
-                    RoomMessageListRestServlet(self).register(resource)
                     RegisterRestServlet(self).register(resource)
-                    LoginRestServlet(self).register(resource)
+                    login.register_servlets(self, resource)
                     ThreepidRestServlet(self).register(resource)
+                    DevicesRestServlet(self).register(resource)
                     KeyQueryServlet(self).register(resource)
+                    OneTimeKeyServlet(self).register(resource)
                     KeyChangesServlet(self).register(resource)
                     VoipRestServlet(self).register(resource)
                     PushRuleRestServlet(self).register(resource)
                     VersionsRestServlet(self).register(resource)
-                    RoomSendEventRestServlet(self).register(resource)
-                    RoomMembershipRestServlet(self).register(resource)
-                    RoomStateEventRestServlet(self).register(resource)
-                    JoinRoomAliasServlet(self).register(resource)
+
                     ProfileAvatarURLRestServlet(self).register(resource)
                     ProfileDisplaynameRestServlet(self).register(resource)
                     ProfileRestServlet(self).register(resource)
                     KeyUploadServlet(self).register(resource)
                     AccountDataServlet(self).register(resource)
                     RoomAccountDataServlet(self).register(resource)
-                    RoomTypingRestServlet(self).register(resource)
 
                     sync.register_servlets(self, resource)
                     events.register_servlets(self, resource)
+                    room.register_servlets(self, resource, True)
+                    room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)
-                    RoomInitialSyncRestServlet(self).register(resource)
+                    room_keys.register_servlets(self, resource)
+                    tags.register_servlets(self, resource)
+                    account_data.register_servlets(self, resource)
+                    receipts.register_servlets(self, resource)
+                    read_marker.register_servlets(self, resource)
+
+                    SendToDeviceRestServlet(self).register(resource)
 
                     user_directory.register_servlets(self, resource)
 
@@ -554,6 +558,8 @@ class GenericWorkerServer(HomeServer):
                     groups.register_servlets(self, resource)
 
                     resources.update({CLIENT_API_PREFIX: resource})
+
+                    resources.update(build_synapse_client_resource_tree(self))
                 elif name == "federation":
                     resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
                 elif name == "media":
@@ -981,9 +987,7 @@ def start(config_options):
     # streams. Will no-op if no streams can be written to by this worker.
     hs.get_replication_streamer()
 
-    reactor.addSystemEventTrigger(
-        "before", "startup", _base.start, hs, config.worker_listeners
-    )
+    register_start(_base.start, hs, config.worker_listeners)
 
     _base.start_worker_reactor("synapse-generic-worker", config)
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2b5465417f..244657cb88 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -15,15 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import logging
 import os
 import sys
-from typing import Iterable
+from typing import Iterable, Iterator
 
-from twisted.application import service
-from twisted.internet import defer, reactor
-from twisted.python.failure import Failure
+from twisted.internet import reactor
 from twisted.web.resource import EncodingResourceWrapper, IResource
 from twisted.web.server import GzipEncoderFactory
 from twisted.web.static import File
@@ -40,7 +37,7 @@ from synapse.api.urls import (
     WEB_CLIENT_PREFIX,
 )
 from synapse.app import _base
-from synapse.app._base import listen_ssl, listen_tcp, quit_with_error
+from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
 from synapse.config._base import ConfigError
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
@@ -63,6 +60,7 @@ from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
 from synapse.storage import DataStore
@@ -71,7 +69,6 @@ from synapse.storage.prepare_database import UpgradeDatabaseException
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.module_loader import load_module
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse.app.homeserver")
@@ -90,7 +87,7 @@ class SynapseHomeServer(HomeServer):
         tls = listener_config.tls
         site_tag = listener_config.http_options.tag
         if site_tag is None:
-            site_tag = port
+            site_tag = str(port)
 
         # We always include a health resource.
         resources = {"/health": HealthResource()}
@@ -107,7 +104,10 @@ class SynapseHomeServer(HomeServer):
         logger.debug("Configuring additional resources: %r", additional_resources)
         module_api = self.get_module_api()
         for path, resmodule in additional_resources.items():
-            handler_cls, config = load_module(resmodule)
+            handler_cls, config = load_module(
+                resmodule,
+                ("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
+            )
             handler = handler_cls(config, module_api)
             if IResource.providedBy(handler):
                 resource = handler
@@ -189,19 +189,10 @@ class SynapseHomeServer(HomeServer):
                     "/_matrix/client/versions": client_resource,
                     "/.well-known/matrix/client": WellKnownResource(self),
                     "/_synapse/admin": AdminRestResource(self),
+                    **build_synapse_client_resource_tree(self),
                 }
             )
 
-            if self.get_config().oidc_enabled:
-                from synapse.rest.oidc import OIDCResource
-
-                resources["/_synapse/oidc"] = OIDCResource(self)
-
-            if self.get_config().saml2_enabled:
-                from synapse.rest.saml2 import SAML2Resource
-
-                resources["/_matrix/saml2"] = SAML2Resource(self)
-
             if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL:
                 from synapse.rest.synapse.client.password_reset import (
                     PasswordResetSubmitTokenResource,
@@ -342,7 +333,10 @@ def setup(config_options):
             "Synapse Homeserver", config_options
         )
     except ConfigError as e:
-        sys.stderr.write("\nERROR: %s\n" % (e,))
+        sys.stderr.write("\n")
+        for f in format_config_error(e):
+            sys.stderr.write(f)
+        sys.stderr.write("\n")
         sys.exit(1)
 
     if not config:
@@ -407,61 +401,62 @@ def setup(config_options):
             _base.refresh_certificate(hs)
 
     async def start():
-        try:
-            # Run the ACME provisioning code, if it's enabled.
-            if hs.config.acme_enabled:
-                acme = hs.get_acme_handler()
-                # Start up the webservices which we will respond to ACME
-                # challenges with, and then provision.
-                await acme.start_listening()
-                await do_acme()
-
-                # Check if it needs to be reprovisioned every day.
-                hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
+        # Run the ACME provisioning code, if it's enabled.
+        if hs.config.acme_enabled:
+            acme = hs.get_acme_handler()
+            # Start up the webservices which we will respond to ACME
+            # challenges with, and then provision.
+            await acme.start_listening()
+            await do_acme()
 
-            # Load the OIDC provider metadatas, if OIDC is enabled.
-            if hs.config.oidc_enabled:
-                oidc = hs.get_oidc_handler()
-                # Loading the provider metadata also ensures the provider config is valid.
-                await oidc.load_metadata()
-                await oidc.load_jwks()
+            # Check if it needs to be reprovisioned every day.
+            hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
 
-            _base.start(hs, config.listeners)
+        # Load the OIDC provider metadatas, if OIDC is enabled.
+        if hs.config.oidc_enabled:
+            oidc = hs.get_oidc_handler()
+            # Loading the provider metadata also ensures the provider config is valid.
+            await oidc.load_metadata()
 
-            hs.get_datastore().db_pool.updates.start_doing_background_updates()
-        except Exception:
-            # Print the exception and bail out.
-            print("Error during startup:", file=sys.stderr)
+        await _base.start(hs, config.listeners)
 
-            # this gives better tracebacks than traceback.print_exc()
-            Failure().printTraceback(file=sys.stderr)
+        hs.get_datastore().db_pool.updates.start_doing_background_updates()
 
-            if reactor.running:
-                reactor.stop()
-            sys.exit(1)
-
-    reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
+    register_start(start)
 
     return hs
 
 
-class SynapseService(service.Service):
+def format_config_error(e: ConfigError) -> Iterator[str]:
     """
-    A twisted Service class that will start synapse. Used to run synapse
-    via twistd and a .tac.
+    Formats a config error neatly
+
+    The idea is to format the immediate error, plus the "causes" of those errors,
+    hopefully in a way that makes sense to the user. For example:
+
+        Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
+          Failed to parse config for module 'JinjaOidcMappingProvider':
+            invalid jinja template:
+              unexpected end of template, expected 'end of print statement'.
+
+    Args:
+        e: the error to be formatted
+
+    Returns: An iterator which yields string fragments to be formatted
     """
+    yield "Error in configuration"
 
-    def __init__(self, config):
-        self.config = config
+    if e.path:
+        yield " at '%s'" % (".".join(e.path),)
 
-    def startService(self):
-        hs = setup(self.config)
-        change_resource_limit(hs.config.soft_file_limit)
-        if hs.config.gc_thresholds:
-            gc.set_threshold(*hs.config.gc_thresholds)
+    yield ":\n  %s" % (e.msg,)
 
-    def stopService(self):
-        return self._port.stopListening()
+    e = e.__cause__
+    indent = 1
+    while e:
+        indent += 1
+        yield ":\n%s%s" % ("  " * indent, str(e))
+        e = e.__cause__
 
 
 def run(hs):
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index c38cf8231f..8f86cecb76 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -93,15 +93,20 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
 
     stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
     stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
+    daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms()
+    stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms
+    stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages()
+    daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages()
+    stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages
     stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
     stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
+    daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
+    stats["daily_sent_messages"] = daily_sent_messages
 
     r30_results = await hs.get_datastore().count_r30_users()
     for name, count in r30_results.items():
         stats["r30_users_" + name] = count
 
-    daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
-    stats["daily_sent_messages"] = daily_sent_messages
     stats["cache_factor"] = hs.config.caches.global_factor
     stats["event_cache_size"] = hs.config.caches.event_cache_size
 
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 85f65da4d9..a851f8801d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,21 +18,31 @@
 import argparse
 import errno
 import os
-import time
-import urllib.parse
 from collections import OrderedDict
 from hashlib import sha256
 from textwrap import dedent
-from typing import Any, Callable, List, MutableMapping, Optional
+from typing import Any, Iterable, List, MutableMapping, Optional
 
 import attr
 import jinja2
 import pkg_resources
 import yaml
 
+from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter
+
 
 class ConfigError(Exception):
-    pass
+    """Represents a problem parsing the configuration
+
+    Args:
+        msg:  A textual description of the error.
+        path: Where appropriate, an indication of where in the configuration
+           the problem lies.
+    """
+
+    def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+        self.msg = msg
+        self.path = path
 
 
 # We split these messages out to allow packages to override with package
@@ -193,11 +203,28 @@ class Config:
         with open(file_path) as file_stream:
             return file_stream.read()
 
+    def read_template(self, filename: str) -> jinja2.Template:
+        """Load a template file from disk.
+
+        This function will attempt to load the given template from the default Synapse
+        template directory.
+
+        Files read are treated as Jinja templates. The templates is not rendered yet
+        and has autoescape enabled.
+
+        Args:
+            filename: A template filename to read.
+
+        Raises:
+            ConfigError: if the file's path is incorrect or otherwise cannot be read.
+
+        Returns:
+            A jinja2 template.
+        """
+        return self.read_templates([filename])[0]
+
     def read_templates(
-        self,
-        filenames: List[str],
-        custom_template_directory: Optional[str] = None,
-        autoescape: bool = False,
+        self, filenames: List[str], custom_template_directory: Optional[str] = None,
     ) -> List[jinja2.Template]:
         """Load a list of template files from disk using the given variables.
 
@@ -205,7 +232,8 @@ class Config:
         template directory. If `custom_template_directory` is supplied, that directory
         is tried first.
 
-        Files read are treated as Jinja templates. These templates are not rendered yet.
+        Files read are treated as Jinja templates. The templates are not rendered yet
+        and have autoescape enabled.
 
         Args:
             filenames: A list of template filenames to read.
@@ -213,16 +241,12 @@ class Config:
             custom_template_directory: A directory to try to look for the templates
                 before using the default Synapse template directory instead.
 
-            autoescape: Whether to autoescape variables before inserting them into the
-                template.
-
         Raises:
             ConfigError: if the file's path is incorrect or otherwise cannot be read.
 
         Returns:
             A list of jinja2 templates.
         """
-        templates = []
         search_directories = [self.default_template_dir]
 
         # The loader will first look in the custom template directory (if specified) for the
@@ -238,54 +262,20 @@ class Config:
             # Search the custom template directory as well
             search_directories.insert(0, custom_template_directory)
 
+        # TODO: switch to synapse.util.templates.build_jinja_env
         loader = jinja2.FileSystemLoader(search_directories)
-        env = jinja2.Environment(loader=loader, autoescape=autoescape)
+        env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
 
         # Update the environment with our custom filters
-        env.filters.update({"format_ts": _format_ts_filter})
-        if self.public_baseurl:
-            env.filters.update(
-                {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)}
-            )
-
-        for filename in filenames:
-            # Load the template
-            template = env.get_template(filename)
-            templates.append(template)
-
-        return templates
-
-
-def _format_ts_filter(value: int, format: str):
-    return time.strftime(format, time.localtime(value / 1000))
-
-
-def _create_mxc_to_http_filter(public_baseurl: str) -> Callable:
-    """Create and return a jinja2 filter that converts MXC urls to HTTP
-
-    Args:
-        public_baseurl: The public, accessible base URL of the homeserver
-    """
-
-    def mxc_to_http_filter(value, width, height, resize_method="crop"):
-        if value[0:6] != "mxc://":
-            return ""
-
-        server_and_media_id = value[6:]
-        fragment = None
-        if "#" in server_and_media_id:
-            server_and_media_id, fragment = server_and_media_id.split("#", 1)
-            fragment = "#" + fragment
-
-        params = {"width": width, "height": height, "method": resize_method}
-        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
-            public_baseurl,
-            server_and_media_id,
-            urllib.parse.urlencode(params),
-            fragment or "",
+        env.filters.update(
+            {
+                "format_ts": _format_ts_filter,
+                "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl),
+            }
         )
 
-    return mxc_to_http_filter
+        # Load the templates
+        return [env.get_template(filename) for filename in filenames]
 
 
 class RootConfig:
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index b8faafa9bd..70025b5d60 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,23 +1,25 @@
-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional
 
 from synapse.config import (
     api,
     appservice,
+    auth,
     captcha,
     cas,
     consent_config,
     database,
     emailconfig,
+    experimental,
     groups,
     jwt_config,
     key,
     logger,
     metrics,
     oidc_config,
-    password,
     password_auth_providers,
     push,
     ratelimiting,
+    redis,
     registration,
     repository,
     room_directory,
@@ -35,7 +37,10 @@ from synapse.config import (
     workers,
 )
 
-class ConfigError(Exception): ...
+class ConfigError(Exception):
+    def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+        self.msg = msg
+        self.path = path
 
 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
 MISSING_REPORT_STATS_SPIEL: str
@@ -45,10 +50,11 @@ def path_exists(file_path: str): ...
 
 class RootConfig:
     server: server.ServerConfig
+    experimental: experimental.ExperimentalConfig
     tls: tls.TlsConfig
     database: database.DatabaseConfig
     logging: logger.LoggingConfig
-    ratelimit: ratelimiting.RatelimitConfig
+    ratelimiting: ratelimiting.RatelimitConfig
     media: repository.ContentRepositoryConfig
     captcha: captcha.CaptchaConfig
     voip: voip.VoipConfig
@@ -62,7 +68,7 @@ class RootConfig:
     sso: sso.SSOConfig
     oidc: oidc_config.OIDCConfig
     jwt: jwt_config.JWTConfig
-    password: password.PasswordConfig
+    auth: auth.AuthConfig
     email: emailconfig.EmailConfig
     worker: workers.WorkerConfig
     authproviders: password_auth_providers.PasswordAuthProviderConfig
@@ -76,6 +82,7 @@ class RootConfig:
     roomdirectory: room_directory.RoomDirectoryConfig
     thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
     tracer: tracer.TracerConfig
+    redis: redis.RedisConfig
 
     config_classes: List = ...
     def __init__(self) -> None: ...
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index c74969a977..8fce7f6bb1 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -38,14 +38,27 @@ def validate_config(
     try:
         jsonschema.validate(config, json_schema)
     except jsonschema.ValidationError as e:
-        # copy `config_path` before modifying it.
-        path = list(config_path)
-        for p in list(e.path):
-            if isinstance(p, int):
-                path.append("<item %i>" % p)
-            else:
-                path.append(str(p))
-
-        raise ConfigError(
-            "Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
-        )
+        raise json_error_to_config_error(e, config_path)
+
+
+def json_error_to_config_error(
+    e: jsonschema.ValidationError, config_path: Iterable[str]
+) -> ConfigError:
+    """Converts a json validation error to a user-readable ConfigError
+
+    Args:
+        e: the exception to be converted
+        config_path: the path within the config file. This will be used as a basis
+           for the error message.
+
+    Returns:
+        a ConfigError
+    """
+    # copy `config_path` before modifying it.
+    path = list(config_path)
+    for p in list(e.absolute_path):
+        if isinstance(p, int):
+            path.append("<item %i>" % p)
+        else:
+            path.append(str(p))
+    return ConfigError(e.message, path)
diff --git a/synapse/config/password.py b/synapse/config/auth.py
index 9c0ea8c30a..2b3e2ce87b 100644
--- a/synapse/config/password.py
+++ b/synapse/config/auth.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,11 +17,11 @@
 from ._base import Config
 
 
-class PasswordConfig(Config):
-    """Password login configuration
+class AuthConfig(Config):
+    """Password and login configuration
     """
 
-    section = "password"
+    section = "auth"
 
     def read_config(self, config, **kwargs):
         password_config = config.get("password_config", {})
@@ -35,6 +36,10 @@ class PasswordConfig(Config):
         self.password_policy = password_config.get("policy") or {}
         self.password_policy_enabled = self.password_policy.get("enabled", False)
 
+        # User-interactive authentication
+        ui_auth = config.get("ui_auth") or {}
+        self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0)
+
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
         password_config:
@@ -87,4 +92,19 @@ class PasswordConfig(Config):
               # Defaults to 'false'.
               #
               #require_uppercase: true
+
+        ui_auth:
+            # The number of milliseconds to allow a user-interactive authentication
+            # session to be active.
+            #
+            # This defaults to 0, meaning the user is queried for their credentials
+            # before every action, but this can be overridden to alow a single
+            # validation to be re-used.  This weakens the protections afforded by
+            # the user-interactive authentication process, by allowing for multiple
+            # (and potentially different) operations to use the same validation session.
+            #
+            # Uncomment below to allow for credential validation to last for 15
+            # seconds.
+            #
+            #session_timeout: 15000
         """
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index cb00958165..9e48f865cc 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -28,9 +28,7 @@ class CaptchaConfig(Config):
             "recaptcha_siteverify_api",
             "https://www.recaptcha.net/recaptcha/api/siteverify",
         )
-        self.recaptcha_template = self.read_templates(
-            ["recaptcha.html"], autoescape=True
-        )[0]
+        self.recaptcha_template = self.read_template("recaptcha.html")
 
     def generate_config_section(self, **kwargs):
         return """\
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 2f97e6d258..aaa7eba110 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import Config
+from ._base import Config, ConfigError
 
 
 class CasConfig(Config):
@@ -30,7 +30,15 @@ class CasConfig(Config):
 
         if self.cas_enabled:
             self.cas_server_url = cas_config["server_url"]
-            self.cas_service_url = cas_config["service_url"]
+
+            # The public baseurl is required because it is used by the redirect
+            # template.
+            public_baseurl = self.public_baseurl
+            if not public_baseurl:
+                raise ConfigError("cas_config requires a public_baseurl to be set")
+
+            # TODO Update this to a _synapse URL.
+            self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket"
             self.cas_displayname_attribute = cas_config.get("displayname_attribute")
             self.cas_required_attributes = cas_config.get("required_attributes") or {}
         else:
@@ -40,7 +48,7 @@ class CasConfig(Config):
             self.cas_required_attributes = {}
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
-        return """
+        return """\
         # Enable Central Authentication Service (CAS) for registration and login.
         #
         cas_config:
@@ -53,10 +61,6 @@ class CasConfig(Config):
           #
           #server_url: "https://cas-server.com"
 
-          # The public URL of the homeserver.
-          #
-          #service_url: "https://homeserver.domain.com:8448"
-
           # The attribute of the CAS response to use as the display name.
           #
           # If unset, no displayname will be set.
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 6efa59b110..c47f364b14 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -89,7 +89,7 @@ class ConsentConfig(Config):
 
     def read_config(self, config, **kwargs):
         consent_config = config.get("user_consent")
-        self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0]
+        self.terms_template = self.read_template("terms.html")
 
         if consent_config is None:
             return
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index cceffbfee2..d4328c46b9 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -322,6 +322,22 @@ class EmailConfig(Config):
 
         self.email_subjects = EmailSubjectConfig(**subjects)
 
+        # The invite client location should be a HTTP(S) URL or None.
+        self.invite_client_location = email_config.get("invite_client_location") or None
+        if self.invite_client_location:
+            if not isinstance(self.invite_client_location, str):
+                raise ConfigError(
+                    "Config option email.invite_client_location must be type str"
+                )
+            if not (
+                self.invite_client_location.startswith("http://")
+                or self.invite_client_location.startswith("https://")
+            ):
+                raise ConfigError(
+                    "Config option email.invite_client_location must be a http or https URL",
+                    path=("email", "invite_client_location"),
+                )
+
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return (
             """\
@@ -389,10 +405,15 @@ class EmailConfig(Config):
           #
           #validation_token_lifetime: 15m
 
-          # Directory in which Synapse will try to find the template files below.
-          # If not set, default templates from within the Synapse package will be used.
+          # The web client location to direct users to during an invite. This is passed
+          # to the identity server as the org.matrix.web_client_location key. Defaults
+          # to unset, giving no guidance to the identity server.
           #
-          # Do not uncomment this setting unless you want to customise the templates.
+          #invite_client_location: https://app.element.io
+
+          # Directory in which Synapse will try to find the template files below.
+          # If not set, or the files named below are not found within the template
+          # directory, default templates from within the Synapse package will be used.
           #
           # Synapse will look for the following templates in this directory:
           #
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
new file mode 100644
index 0000000000..b1c1c51e4d
--- /dev/null
+++ b/synapse/config/experimental.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config
+from synapse.types import JsonDict
+
+
+class ExperimentalConfig(Config):
+    """Config section for enabling experimental features"""
+
+    section = "experimental"
+
+    def read_config(self, config: JsonDict, **kwargs):
+        experimental = config.get("experimental_features") or {}
+
+        # MSC2858 (multiple SSO identity providers)
+        self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index ffd8fca54e..9f3c57e6a1 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -12,12 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Optional
 
-from netaddr import IPSet
-
-from synapse.config._base import Config, ConfigError
+from synapse.config._base import Config
 from synapse.config._util import validate_config
 
 
@@ -36,23 +33,6 @@ class FederationConfig(Config):
             for domain in federation_domain_whitelist:
                 self.federation_domain_whitelist[domain] = True
 
-        self.federation_ip_range_blacklist = config.get(
-            "federation_ip_range_blacklist", []
-        )
-
-        # Attempt to create an IPSet from the given ranges
-        try:
-            self.federation_ip_range_blacklist = IPSet(
-                self.federation_ip_range_blacklist
-            )
-
-            # Always blacklist 0.0.0.0, ::
-            self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
-        except Exception as e:
-            raise ConfigError(
-                "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
-            )
-
         federation_metrics_domains = config.get("federation_metrics_domains") or []
         validate_config(
             _METRICS_FOR_DOMAINS_SCHEMA,
@@ -76,27 +56,6 @@ class FederationConfig(Config):
         #  - nyc.example.com
         #  - syd.example.com
 
-        # Prevent federation requests from being sent to the following
-        # blacklist IP address CIDR ranges. If this option is not specified, or
-        # specified with an empty list, no ip range blacklist will be enforced.
-        #
-        # As of Synapse v1.4.0 this option also affects any outbound requests to identity
-        # servers provided by user input.
-        #
-        # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
-        # listed here, since they correspond to unroutable addresses.)
-        #
-        federation_ip_range_blacklist:
-          - '127.0.0.0/8'
-          - '10.0.0.0/8'
-          - '172.16.0.0/12'
-          - '192.168.0.0/16'
-          - '100.64.0.0/10'
-          - '169.254.0.0/16'
-          - '::1/128'
-          - 'fe80::/64'
-          - 'fc00::/7'
-
         # Report prometheus metrics on the age of PDUs being sent to and received from
         # the following domains. This can be used to give an idea of "delay" on inbound
         # and outbound federation, though be aware that any delay can be due to problems
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index d6862d9a64..7b7860ea71 100644
--- a/synapse/config/groups.py
+++ b/synapse/config/groups.py
@@ -32,5 +32,5 @@ class GroupsConfig(Config):
         # If enabled, non server admins can only create groups with local parts
         # starting with this prefix
         #
-        #group_creation_prefix: "unofficial/"
+        #group_creation_prefix: "unofficial_"
         """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index be65554524..64a2429f77 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -17,12 +17,14 @@
 from ._base import RootConfig
 from .api import ApiConfig
 from .appservice import AppServiceConfig
+from .auth import AuthConfig
 from .cache import CacheConfig
 from .captcha import CaptchaConfig
 from .cas import CasConfig
 from .consent_config import ConsentConfig
 from .database import DatabaseConfig
 from .emailconfig import EmailConfig
+from .experimental import ExperimentalConfig
 from .federation import FederationConfig
 from .groups import GroupsConfig
 from .jwt_config import JWTConfig
@@ -30,7 +32,6 @@ from .key import KeyConfig
 from .logger import LoggingConfig
 from .metrics import MetricsConfig
 from .oidc_config import OIDCConfig
-from .password import PasswordConfig
 from .password_auth_providers import PasswordAuthProviderConfig
 from .push import PushConfig
 from .ratelimiting import RatelimitConfig
@@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig):
 
     config_classes = [
         ServerConfig,
+        ExperimentalConfig,
         TlsConfig,
         FederationConfig,
         CacheConfig,
@@ -76,7 +78,7 @@ class HomeServerConfig(RootConfig):
         CasConfig,
         SSOConfig,
         JWTConfig,
-        PasswordConfig,
+        AuthConfig,
         EmailConfig,
         PasswordAuthProviderConfig,
         PushConfig,
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index d4e887a3e0..4df3f93c1c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
     # filter options, but care must when using e.g. MemoryHandler to buffer
     # writes.
 
-    log_context_filter = LoggingContextFilter(request="")
+    log_context_filter = LoggingContextFilter()
     log_metadata_filter = MetadataFilter({"server_name": config.server_name})
     old_factory = logging.getLogRecordFactory()
 
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 69d188341c..4d0f24a9d5 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2020 Quentin Gliech
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,8 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import Counter
+from typing import Iterable, Optional, Tuple, Type
+
+import attr
+
+from synapse.config._util import validate_config
 from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.types import Collection, JsonDict
 from synapse.util.module_loader import load_module
+from synapse.util.stringutils import parse_and_validate_mxc_uri
 
 from ._base import Config, ConfigError
 
@@ -25,201 +34,457 @@ class OIDCConfig(Config):
     section = "oidc"
 
     def read_config(self, config, **kwargs):
-        self.oidc_enabled = False
-
-        oidc_config = config.get("oidc_config")
-
-        if not oidc_config or not oidc_config.get("enabled", False):
+        self.oidc_providers = tuple(_parse_oidc_provider_configs(config))
+        if not self.oidc_providers:
             return
 
         try:
             check_requirements("oidc")
         except DependencyException as e:
-            raise ConfigError(e.message)
+            raise ConfigError(e.message) from e
+
+        # check we don't have any duplicate idp_ids now. (The SSO handler will also
+        # check for duplicates when the REST listeners get registered, but that happens
+        # after synapse has forked so doesn't give nice errors.)
+        c = Counter([i.idp_id for i in self.oidc_providers])
+        for idp_id, count in c.items():
+            if count > 1:
+                raise ConfigError(
+                    "Multiple OIDC providers have the idp_id %r." % idp_id
+                )
 
         public_baseurl = self.public_baseurl
         if public_baseurl is None:
             raise ConfigError("oidc_config requires a public_baseurl to be set")
-        self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
-
-        self.oidc_enabled = True
-        self.oidc_discover = oidc_config.get("discover", True)
-        self.oidc_issuer = oidc_config["issuer"]
-        self.oidc_client_id = oidc_config["client_id"]
-        self.oidc_client_secret = oidc_config["client_secret"]
-        self.oidc_client_auth_method = oidc_config.get(
-            "client_auth_method", "client_secret_basic"
-        )
-        self.oidc_scopes = oidc_config.get("scopes", ["openid"])
-        self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint")
-        self.oidc_token_endpoint = oidc_config.get("token_endpoint")
-        self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint")
-        self.oidc_jwks_uri = oidc_config.get("jwks_uri")
-        self.oidc_skip_verification = oidc_config.get("skip_verification", False)
-        self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto")
-        self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False)
-
-        ump_config = oidc_config.get("user_mapping_provider", {})
-        ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
-        ump_config.setdefault("config", {})
-
-        (
-            self.oidc_user_mapping_provider_class,
-            self.oidc_user_mapping_provider_config,
-        ) = load_module(ump_config)
-
-        # Ensure loaded user mapping module has defined all necessary methods
-        required_methods = [
-            "get_remote_user_id",
-            "map_user_attributes",
-        ]
-        missing_methods = [
-            method
-            for method in required_methods
-            if not hasattr(self.oidc_user_mapping_provider_class, method)
-        ]
-        if missing_methods:
-            raise ConfigError(
-                "Class specified by oidc_config."
-                "user_mapping_provider.module is missing required "
-                "methods: %s" % (", ".join(missing_methods),)
-            )
+        self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback"
+
+    @property
+    def oidc_enabled(self) -> bool:
+        # OIDC is enabled if we have a provider
+        return bool(self.oidc_providers)
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
-        # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login.
+        # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration
+        # and login.
+        #
+        # Options for each entry include:
+        #
+        #   idp_id: a unique identifier for this identity provider. Used internally
+        #       by Synapse; should be a single word such as 'github'.
+        #
+        #       Note that, if this is changed, users authenticating via that provider
+        #       will no longer be recognised as the same user!
+        #
+        #   idp_name: A user-facing name for this identity provider, which is used to
+        #       offer the user a choice of login mechanisms.
+        #
+        #   idp_icon: An optional icon for this identity provider, which is presented
+        #       by clients and Synapse's own IdP picker page. If given, must be an
+        #       MXC URI of the format mxc://<server-name>/<media-id>. (An easy way to
+        #       obtain such an MXC URI is to upload an image to an (unencrypted) room
+        #       and then copy the "url" from the source of the event.)
+        #
+        #   idp_brand: An optional brand for this identity provider, allowing clients
+        #       to style the login flow according to the identity provider in question.
+        #       See the spec for possible options here.
+        #
+        #   discover: set to 'false' to disable the use of the OIDC discovery mechanism
+        #       to discover endpoints. Defaults to true.
+        #
+        #   issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery
+        #       is enabled) to discover the provider's endpoints.
+        #
+        #   client_id: Required. oauth2 client id to use.
+        #
+        #   client_secret: Required. oauth2 client secret to use.
+        #
+        #   client_auth_method: auth method to use when exchanging the token. Valid
+        #       values are 'client_secret_basic' (default), 'client_secret_post' and
+        #       'none'.
+        #
+        #   scopes: list of scopes to request. This should normally include the "openid"
+        #       scope. Defaults to ["openid"].
+        #
+        #   authorization_endpoint: the oauth2 authorization endpoint. Required if
+        #       provider discovery is disabled.
+        #
+        #   token_endpoint: the oauth2 token endpoint. Required if provider discovery is
+        #       disabled.
+        #
+        #   userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is
+        #       disabled and the 'openid' scope is not requested.
+        #
+        #   jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and
+        #       the 'openid' scope is used.
+        #
+        #   skip_verification: set to 'true' to skip metadata verification. Use this if
+        #       you are connecting to a provider that is not OpenID Connect compliant.
+        #       Defaults to false. Avoid this in production.
+        #
+        #   user_profile_method: Whether to fetch the user profile from the userinfo
+        #       endpoint. Valid values are: 'auto' or 'userinfo_endpoint'.
+        #
+        #       Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is
+        #       included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the
+        #       userinfo endpoint.
+        #
+        #   allow_existing_users: set to 'true' to allow a user logging in via OIDC to
+        #       match a pre-existing account instead of failing. This could be used if
+        #       switching from password logins to OIDC. Defaults to false.
+        #
+        #   user_mapping_provider: Configuration for how attributes returned from a OIDC
+        #       provider are mapped onto a matrix user. This setting has the following
+        #       sub-properties:
+        #
+        #       module: The class name of a custom mapping module. Default is
+        #           {mapping_provider!r}.
+        #           See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
+        #           for information on implementing a custom mapping provider.
+        #
+        #       config: Configuration for the mapping provider module. This section will
+        #           be passed as a Python dictionary to the user mapping provider
+        #           module's `parse_config` method.
+        #
+        #           For the default provider, the following settings are available:
+        #
+        #             subject_claim: name of the claim containing a unique identifier
+        #                 for the user. Defaults to 'sub', which OpenID Connect
+        #                 compliant providers should provide.
+        #
+        #             localpart_template: Jinja2 template for the localpart of the MXID.
+        #                 If this is not set, the user will be prompted to choose their
+        #                 own username (see 'sso_auth_account_details.html' in the 'sso'
+        #                 section of this file).
+        #
+        #             display_name_template: Jinja2 template for the display name to set
+        #                 on first login. If unset, no displayname will be set.
+        #
+        #             email_template: Jinja2 template for the email address of the user.
+        #                 If unset, no email address will be added to the account.
+        #
+        #             extra_attributes: a map of Jinja2 templates for extra attributes
+        #                 to send back to the client during login.
+        #                 Note that these are non-standard and clients will ignore them
+        #                 without modifications.
+        #
+        #           When rendering, the Jinja2 templates are given a 'user' variable,
+        #           which is set to the claims returned by the UserInfo Endpoint and/or
+        #           in the ID Token.
         #
         # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
-        # for some example configurations.
+        # for information on how to configure these options.
         #
-        oidc_config:
-          # Uncomment the following to enable authorization against an OpenID Connect
-          # server. Defaults to false.
-          #
-          #enabled: true
-
-          # Uncomment the following to disable use of the OIDC discovery mechanism to
-          # discover endpoints. Defaults to true.
-          #
-          #discover: false
-
-          # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
-          # discover the provider's endpoints.
-          #
-          # Required if 'enabled' is true.
-          #
-          #issuer: "https://accounts.example.com/"
-
-          # oauth2 client id to use.
-          #
-          # Required if 'enabled' is true.
-          #
-          #client_id: "provided-by-your-issuer"
-
-          # oauth2 client secret to use.
-          #
-          # Required if 'enabled' is true.
-          #
-          #client_secret: "provided-by-your-issuer"
-
-          # auth method to use when exchanging the token.
-          # Valid values are 'client_secret_basic' (default), 'client_secret_post' and
-          # 'none'.
-          #
-          #client_auth_method: client_secret_post
-
-          # list of scopes to request. This should normally include the "openid" scope.
-          # Defaults to ["openid"].
+        # For backwards compatibility, it is also possible to configure a single OIDC
+        # provider via an 'oidc_config' setting. This is now deprecated and admins are
+        # advised to migrate to the 'oidc_providers' format. (When doing that migration,
+        # use 'oidc' for the idp_id to ensure that existing users continue to be
+        # recognised.)
+        #
+        oidc_providers:
+          # Generic example
           #
-          #scopes: ["openid", "profile"]
-
-          # the oauth2 authorization endpoint. Required if provider discovery is disabled.
+          #- idp_id: my_idp
+          #  idp_name: "My OpenID provider"
+          #  idp_icon: "mxc://example.com/mediaid"
+          #  discover: false
+          #  issuer: "https://accounts.example.com/"
+          #  client_id: "provided-by-your-issuer"
+          #  client_secret: "provided-by-your-issuer"
+          #  client_auth_method: client_secret_post
+          #  scopes: ["openid", "profile"]
+          #  authorization_endpoint: "https://accounts.example.com/oauth2/auth"
+          #  token_endpoint: "https://accounts.example.com/oauth2/token"
+          #  userinfo_endpoint: "https://accounts.example.com/userinfo"
+          #  jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+          #  skip_verification: true
+          #  user_mapping_provider:
+          #    config:
+          #      subject_claim: "id"
+          #      localpart_template: "{{ user.login }}"
+          #      display_name_template: "{{ user.name }}"
+          #      email_template: "{{ user.email }}"
+
+          # For use with Keycloak
           #
-          #authorization_endpoint: "https://accounts.example.com/oauth2/auth"
-
-          # the oauth2 token endpoint. Required if provider discovery is disabled.
+          #- idp_id: keycloak
+          #  idp_name: Keycloak
+          #  issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name"
+          #  client_id: "synapse"
+          #  client_secret: "copy secret generated in Keycloak UI"
+          #  scopes: ["openid", "profile"]
+
+          # For use with Github
           #
-          #token_endpoint: "https://accounts.example.com/oauth2/token"
+          #- idp_id: github
+          #  idp_name: Github
+          #  idp_brand: org.matrix.github
+          #  discover: false
+          #  issuer: "https://github.com/"
+          #  client_id: "your-client-id" # TO BE FILLED
+          #  client_secret: "your-client-secret" # TO BE FILLED
+          #  authorization_endpoint: "https://github.com/login/oauth/authorize"
+          #  token_endpoint: "https://github.com/login/oauth/access_token"
+          #  userinfo_endpoint: "https://api.github.com/user"
+          #  scopes: ["read:user"]
+          #  user_mapping_provider:
+          #    config:
+          #      subject_claim: "id"
+          #      localpart_template: "{{ user.login }}"
+          #      display_name_template: "{{ user.name }}"
+        """.format(
+            mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
+        )
 
-          # the OIDC userinfo endpoint. Required if discovery is disabled and the
-          # "openid" scope is not requested.
-          #
-          #userinfo_endpoint: "https://accounts.example.com/userinfo"
 
-          # URI where to fetch the JWKS. Required if discovery is disabled and the
-          # "openid" scope is used.
-          #
-          #jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
+# jsonschema definition of the configuration settings for an oidc identity provider
+OIDC_PROVIDER_CONFIG_SCHEMA = {
+    "type": "object",
+    "required": ["issuer", "client_id", "client_secret"],
+    "properties": {
+        "idp_id": {
+            "type": "string",
+            "minLength": 1,
+            # MSC2858 allows a maxlen of 255, but we prefix with "oidc-"
+            "maxLength": 250,
+            "pattern": "^[A-Za-z0-9._~-]+$",
+        },
+        "idp_name": {"type": "string"},
+        "idp_icon": {"type": "string"},
+        "idp_brand": {
+            "type": "string",
+            # MSC2758-style namespaced identifier
+            "minLength": 1,
+            "maxLength": 255,
+            "pattern": "^[a-z][a-z0-9_.-]*$",
+        },
+        "discover": {"type": "boolean"},
+        "issuer": {"type": "string"},
+        "client_id": {"type": "string"},
+        "client_secret": {"type": "string"},
+        "client_auth_method": {
+            "type": "string",
+            # the following list is the same as the keys of
+            # authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
+            # to avoid importing authlib here.
+            "enum": ["client_secret_basic", "client_secret_post", "none"],
+        },
+        "scopes": {"type": "array", "items": {"type": "string"}},
+        "authorization_endpoint": {"type": "string"},
+        "token_endpoint": {"type": "string"},
+        "userinfo_endpoint": {"type": "string"},
+        "jwks_uri": {"type": "string"},
+        "skip_verification": {"type": "boolean"},
+        "user_profile_method": {
+            "type": "string",
+            "enum": ["auto", "userinfo_endpoint"],
+        },
+        "allow_existing_users": {"type": "boolean"},
+        "user_mapping_provider": {"type": ["object", "null"]},
+    },
+}
+
+# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name
+OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = {
+    "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}]
+}
+
+
+# the `oidc_providers` list can either be None (as it is in the default config), or
+# a list of provider configs, each of which requires an explicit ID and name.
+OIDC_PROVIDER_LIST_SCHEMA = {
+    "oneOf": [
+        {"type": "null"},
+        {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA},
+    ]
+}
+
+# the `oidc_config` setting can either be None (which it used to be in the default
+# config), or an object. If an object, it is ignored unless it has an "enabled: True"
+# property.
+#
+# It's *possible* to represent this with jsonschema, but the resultant errors aren't
+# particularly clear, so we just check for either an object or a null here, and do
+# additional checks in the code.
+OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]}
+
+# the top-level schema can contain an "oidc_config" and/or an "oidc_providers".
+MAIN_CONFIG_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "oidc_config": OIDC_CONFIG_SCHEMA,
+        "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA,
+    },
+}
+
+
+def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]:
+    """extract and parse the OIDC provider configs from the config dict
+
+    The configuration may contain either a single `oidc_config` object with an
+    `enabled: True` property, or a list of provider configurations under
+    `oidc_providers`, *or both*.
+
+    Returns a generator which yields the OidcProviderConfig objects
+    """
+    validate_config(MAIN_CONFIG_SCHEMA, config, ())
+
+    for i, p in enumerate(config.get("oidc_providers") or []):
+        yield _parse_oidc_config_dict(p, ("oidc_providers", "<item %i>" % (i,)))
+
+    # for backwards-compatibility, it is also possible to provide a single "oidc_config"
+    # object with an "enabled: True" property.
+    oidc_config = config.get("oidc_config")
+    if oidc_config and oidc_config.get("enabled", False):
+        # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that
+        # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA
+        # above), so now we need to validate it.
+        validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
+        yield _parse_oidc_config_dict(oidc_config, ("oidc_config",))
+
+
+def _parse_oidc_config_dict(
+    oidc_config: JsonDict, config_path: Tuple[str, ...]
+) -> "OidcProviderConfig":
+    """Take the configuration dict and parse it into an OidcProviderConfig
+
+    Raises:
+        ConfigError if the configuration is malformed.
+    """
+    ump_config = oidc_config.get("user_mapping_provider", {})
+    ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+    ump_config.setdefault("config", {})
+
+    (user_mapping_provider_class, user_mapping_provider_config,) = load_module(
+        ump_config, config_path + ("user_mapping_provider",)
+    )
+
+    # Ensure loaded user mapping module has defined all necessary methods
+    required_methods = [
+        "get_remote_user_id",
+        "map_user_attributes",
+    ]
+    missing_methods = [
+        method
+        for method in required_methods
+        if not hasattr(user_mapping_provider_class, method)
+    ]
+    if missing_methods:
+        raise ConfigError(
+            "Class %s is missing required "
+            "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
+            config_path + ("user_mapping_provider", "module"),
+        )
 
-          # Uncomment to skip metadata verification. Defaults to false.
-          #
-          # Use this if you are connecting to a provider that is not OpenID Connect
-          # compliant.
-          # Avoid this in production.
-          #
-          #skip_verification: true
+    idp_id = oidc_config.get("idp_id", "oidc")
 
-          # Whether to fetch the user profile from the userinfo endpoint. Valid
-          # values are: "auto" or "userinfo_endpoint".
-          #
-          # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included
-          # in `scopes`. Uncomment the following to always fetch the userinfo endpoint.
-          #
-          #user_profile_method: "userinfo_endpoint"
+    # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid
+    # clashes with other mechs (such as SAML, CAS).
+    #
+    # We allow "oidc" as an exception so that people migrating from old-style
+    # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to
+    # a new-style "oidc_providers" entry without changing the idp_id for their provider
+    # (and thereby invalidating their user_external_ids data).
 
-          # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead
-          # of failing. This could be used if switching from password logins to OIDC. Defaults to false.
-          #
-          #allow_existing_users: true
+    if idp_id != "oidc":
+        idp_id = "oidc-" + idp_id
 
-          # An external module can be provided here as a custom solution to mapping
-          # attributes returned from a OIDC provider onto a matrix user.
-          #
-          user_mapping_provider:
-            # The custom module's class. Uncomment to use a custom module.
-            # Default is {mapping_provider!r}.
-            #
-            # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers
-            # for information on implementing a custom mapping provider.
-            #
-            #module: mapping_provider.OidcMappingProvider
-
-            # Custom configuration values for the module. This section will be passed as
-            # a Python dictionary to the user mapping provider module's `parse_config`
-            # method.
-            #
-            # The examples below are intended for the default provider: they should be
-            # changed if using a custom provider.
-            #
-            config:
-              # name of the claim containing a unique identifier for the user.
-              # Defaults to `sub`, which OpenID Connect compliant providers should provide.
-              #
-              #subject_claim: "sub"
-
-              # Jinja2 template for the localpart of the MXID.
-              #
-              # When rendering, this template is given the following variables:
-              #   * user: The claims returned by the UserInfo Endpoint and/or in the ID
-              #     Token
-              #
-              # This must be configured if using the default mapping provider.
-              #
-              localpart_template: "{{{{ user.preferred_username }}}}"
-
-              # Jinja2 template for the display name to set on first login.
-              #
-              # If unset, no displayname will be set.
-              #
-              #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
-
-              # Jinja2 templates for extra attributes to send back to the client during
-              # login.
-              #
-              # Note that these are non-standard and clients will ignore them without modifications.
-              #
-              #extra_attributes:
-                #birthdate: "{{{{ user.birthdate }}}}"
-        """.format(
-            mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
-        )
+    # MSC2858 also specifies that the idp_icon must be a valid MXC uri
+    idp_icon = oidc_config.get("idp_icon")
+    if idp_icon is not None:
+        try:
+            parse_and_validate_mxc_uri(idp_icon)
+        except ValueError as e:
+            raise ConfigError(
+                "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
+            ) from e
+
+    return OidcProviderConfig(
+        idp_id=idp_id,
+        idp_name=oidc_config.get("idp_name", "OIDC"),
+        idp_icon=idp_icon,
+        idp_brand=oidc_config.get("idp_brand"),
+        discover=oidc_config.get("discover", True),
+        issuer=oidc_config["issuer"],
+        client_id=oidc_config["client_id"],
+        client_secret=oidc_config["client_secret"],
+        client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
+        scopes=oidc_config.get("scopes", ["openid"]),
+        authorization_endpoint=oidc_config.get("authorization_endpoint"),
+        token_endpoint=oidc_config.get("token_endpoint"),
+        userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
+        jwks_uri=oidc_config.get("jwks_uri"),
+        skip_verification=oidc_config.get("skip_verification", False),
+        user_profile_method=oidc_config.get("user_profile_method", "auto"),
+        allow_existing_users=oidc_config.get("allow_existing_users", False),
+        user_mapping_provider_class=user_mapping_provider_class,
+        user_mapping_provider_config=user_mapping_provider_config,
+    )
+
+
+@attr.s(slots=True, frozen=True)
+class OidcProviderConfig:
+    # a unique identifier for this identity provider. Used in the 'user_external_ids'
+    # table, as well as the query/path parameter used in the login protocol.
+    idp_id = attr.ib(type=str)
+
+    # user-facing name for this identity provider.
+    idp_name = attr.ib(type=str)
+
+    # Optional MXC URI for icon for this IdP.
+    idp_icon = attr.ib(type=Optional[str])
+
+    # Optional brand identifier for this IdP.
+    idp_brand = attr.ib(type=Optional[str])
+
+    # whether the OIDC discovery mechanism is used to discover endpoints
+    discover = attr.ib(type=bool)
+
+    # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
+    # discover the provider's endpoints.
+    issuer = attr.ib(type=str)
+
+    # oauth2 client id to use
+    client_id = attr.ib(type=str)
+
+    # oauth2 client secret to use
+    client_secret = attr.ib(type=str)
+
+    # auth method to use when exchanging the token.
+    # Valid values are 'client_secret_basic', 'client_secret_post' and
+    # 'none'.
+    client_auth_method = attr.ib(type=str)
+
+    # list of scopes to request
+    scopes = attr.ib(type=Collection[str])
+
+    # the oauth2 authorization endpoint. Required if discovery is disabled.
+    authorization_endpoint = attr.ib(type=Optional[str])
+
+    # the oauth2 token endpoint. Required if discovery is disabled.
+    token_endpoint = attr.ib(type=Optional[str])
+
+    # the OIDC userinfo endpoint. Required if discovery is disabled and the
+    # "openid" scope is not requested.
+    userinfo_endpoint = attr.ib(type=Optional[str])
+
+    # URI where to fetch the JWKS. Required if discovery is disabled and the
+    # "openid" scope is used.
+    jwks_uri = attr.ib(type=Optional[str])
+
+    # Whether to skip metadata verification
+    skip_verification = attr.ib(type=bool)
+
+    # Whether to fetch the user profile from the userinfo endpoint. Valid
+    # values are: "auto" or "userinfo_endpoint".
+    user_profile_method = attr.ib(type=str)
+
+    # whether to allow a user logging in via OIDC to match a pre-existing account
+    # instead of failing
+    allow_existing_users = attr.ib(type=bool)
+
+    # the class of the user mapping provider
+    user_mapping_provider_class = attr.ib(type=Type)
+
+    # the config of the user mapping provider
+    user_mapping_provider_config = attr.ib()
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 4fda8ae987..85d07c4f8f 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config):
             providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
 
         providers.extend(config.get("password_providers") or [])
-        for provider in providers:
+        for i, provider in enumerate(providers):
             mod_name = provider["module"]
 
             # This is for backwards compat when the ldap auth provider resided
@@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config):
                 mod_name = LDAP_PROVIDER
 
             (provider_class, provider_config) = load_module(
-                {"module": mod_name, "config": provider["config"]}
+                {"module": mod_name, "config": provider["config"]},
+                ("password_providers", "<item %i>" % i),
             )
 
             self.password_providers.append((provider_class, provider_config))
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 14b8836197..def33a60ad 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -24,7 +24,7 @@ class RateLimitConfig:
         defaults={"per_second": 0.17, "burst_count": 3.0},
     ):
         self.per_second = config.get("per_second", defaults["per_second"])
-        self.burst_count = config.get("burst_count", defaults["burst_count"])
+        self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
 
 
 class FederationRateLimitConfig:
@@ -102,6 +102,20 @@ class RatelimitConfig(Config):
             defaults={"per_second": 0.01, "burst_count": 3},
         )
 
+        self.rc_3pid_validation = RateLimitConfig(
+            config.get("rc_3pid_validation") or {},
+            defaults={"per_second": 0.003, "burst_count": 5},
+        )
+
+        self.rc_invites_per_room = RateLimitConfig(
+            config.get("rc_invites", {}).get("per_room", {}),
+            defaults={"per_second": 0.3, "burst_count": 10},
+        )
+        self.rc_invites_per_user = RateLimitConfig(
+            config.get("rc_invites", {}).get("per_user", {}),
+            defaults={"per_second": 0.003, "burst_count": 5},
+        )
+
     def generate_config_section(self, **kwargs):
         return """\
         ## Ratelimiting ##
@@ -131,6 +145,9 @@ class RatelimitConfig(Config):
         #     users are joining rooms the server is already in (this is cheap) vs
         #     "remote" for when users are trying to join rooms not on the server (which
         #     can be more expensive)
+        #   - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
+        #   - two for ratelimiting how often invites can be sent in a room or to a
+        #     specific user.
         #
         # The defaults are as shown below.
         #
@@ -164,7 +181,18 @@ class RatelimitConfig(Config):
         #  remote:
         #    per_second: 0.01
         #    burst_count: 3
-
+        #
+        #rc_3pid_validation:
+        #  per_second: 0.003
+        #  burst_count: 5
+        #
+        #rc_invites:
+        #  per_room:
+        #    per_second: 0.3
+        #    burst_count: 10
+        #  per_user:
+        #    per_second: 0.003
+        #    burst_count: 5
 
         # Ratelimiting settings for incoming federation
         #
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 82be5a35aa..0b4c21b338 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -14,14 +14,13 @@
 # limitations under the License.
 
 import os
-from distutils.util import strtobool
 
 import pkg_resources
 
 from synapse.api.constants import RoomCreationPreset
 from synapse.config._base import Config, ConfigError
 from synapse.types import RoomAlias, UserID
-from synapse.util.stringutils import random_string_with_symbols
+from synapse.util.stringutils import random_string_with_symbols, strtobool
 
 
 class AccountValidityConfig(Config):
@@ -86,12 +85,12 @@ class RegistrationConfig(Config):
     section = "registration"
 
     def read_config(self, config, **kwargs):
-        self.enable_registration = bool(
-            strtobool(str(config.get("enable_registration", False)))
+        self.enable_registration = strtobool(
+            str(config.get("enable_registration", False))
         )
         if "disable_registration" in config:
-            self.enable_registration = not bool(
-                strtobool(str(config["disable_registration"]))
+            self.enable_registration = not strtobool(
+                str(config["disable_registration"])
             )
 
         self.account_validity = AccountValidityConfig(
@@ -191,9 +190,7 @@ class RegistrationConfig(Config):
         self.session_lifetime = session_lifetime
 
         # The success template used during fallback auth.
-        self.fallback_success_template = self.read_templates(
-            ["auth_success.html"], autoescape=True
-        )[0]
+        self.fallback_success_template = self.read_template("auth_success.html")
 
     def generate_config_section(self, generate_secrets=False, **kwargs):
         if generate_secrets:
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index ba1e9d2361..850ac3ebd6 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -17,6 +17,9 @@ import os
 from collections import namedtuple
 from typing import Dict, List
 
+from netaddr import IPSet
+
+from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST
 from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.util.module_loader import load_module
 
@@ -142,7 +145,7 @@ class ContentRepositoryConfig(Config):
         # them to be started.
         self.media_storage_providers = []  # type: List[tuple]
 
-        for provider_config in storage_providers:
+        for i, provider_config in enumerate(storage_providers):
             # We special case the module "file_system" so as not to need to
             # expose FileStorageProviderBackend
             if provider_config["module"] == "file_system":
@@ -151,7 +154,9 @@ class ContentRepositoryConfig(Config):
                     ".FileStorageProviderBackend"
                 )
 
-            provider_class, parsed_config = load_module(provider_config)
+            provider_class, parsed_config = load_module(
+                provider_config, ("media_storage_providers", "<item %i>" % i)
+            )
 
             wrapper_config = MediaStorageProviderConfig(
                 provider_config.get("store_local", False),
@@ -182,9 +187,6 @@ class ContentRepositoryConfig(Config):
                     "to work"
                 )
 
-            # netaddr is a dependency for url_preview
-            from netaddr import IPSet
-
             self.url_preview_ip_range_blacklist = IPSet(
                 config["url_preview_ip_range_blacklist"]
             )
@@ -213,6 +215,10 @@ class ContentRepositoryConfig(Config):
         # strip final NL
         formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
 
+        ip_range_blacklist = "\n".join(
+            "        #  - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
+        )
+
         return (
             r"""
         ## Media Store ##
@@ -283,15 +289,7 @@ class ContentRepositoryConfig(Config):
         # you uncomment the following list as a starting point.
         #
         #url_preview_ip_range_blacklist:
-        #  - '127.0.0.0/8'
-        #  - '10.0.0.0/8'
-        #  - '172.16.0.0/12'
-        #  - '192.168.0.0/16'
-        #  - '100.64.0.0/10'
-        #  - '169.254.0.0/16'
-        #  - '::1/128'
-        #  - 'fe80::/64'
-        #  - 'fc00::/7'
+%(ip_range_blacklist)s
 
         # List of IP address CIDR ranges that the URL preview spider is allowed
         # to access even if they are specified in url_preview_ip_range_blacklist.
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 92e1b67528..9a3e1c3e7d 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -180,7 +180,7 @@ class _RoomDirectoryRule:
             self._alias_regex = glob_to_regex(alias)
             self._room_id_regex = glob_to_regex(room_id)
         except Exception as e:
-            raise ConfigError("Failed to parse glob into regex: %s", e)
+            raise ConfigError("Failed to parse glob into regex") from e
 
     def matches(self, user_id, room_id, aliases):
         """Tests if this rule matches the given user_id, room_id and aliases.
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c1b8e98ae0..7226abd829 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -125,7 +125,7 @@ class SAML2Config(Config):
         (
             self.saml2_user_mapping_provider_class,
             self.saml2_user_mapping_provider_config,
-        ) = load_module(ump_dict)
+        ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
 
         # Ensure loaded user mapping module has defined all necessary methods
         # Note parse_config() is already checked during the call to load_module
@@ -196,8 +196,8 @@ class SAML2Config(Config):
             optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
         optional_attributes -= required_attributes
 
-        metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
-        response_url = public_baseurl + "_matrix/saml2/authn_response"
+        metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml"
+        response_url = public_baseurl + "_synapse/client/saml2/authn_response"
         return {
             "entityid": metadata_url,
             "service": {
@@ -235,10 +235,10 @@ class SAML2Config(Config):
         # enable SAML login.
         #
         # Once SAML support is enabled, a metadata file will be exposed at
-        # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
+        # https://<server>:<port>/_synapse/client/saml2/metadata.xml, which you may be able to
         # use to configure your SAML IdP with. Alternatively, you can manually configure
         # the IdP to use an ACS location of
-        # https://<server>:<port>/_matrix/saml2/authn_response.
+        # https://<server>:<port>/_synapse/client/saml2/authn_response.
         #
         saml2_config:
           # `sp_config` is the configuration for the pysaml2 Service Provider.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 85aa49c02d..5d72cf2d82 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -23,9 +23,10 @@ from typing import Any, Dict, Iterable, List, Optional, Set
 
 import attr
 import yaml
+from netaddr import IPSet
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.http.endpoint import parse_and_validate_server_name
+from synapse.util.stringutils import parse_and_validate_server_name
 
 from ._base import Config, ConfigError
 
@@ -39,6 +40,34 @@ logger = logging.Logger(__name__)
 # in the list.
 DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
 
+DEFAULT_IP_RANGE_BLACKLIST = [
+    # Localhost
+    "127.0.0.0/8",
+    # Private networks.
+    "10.0.0.0/8",
+    "172.16.0.0/12",
+    "192.168.0.0/16",
+    # Carrier grade NAT.
+    "100.64.0.0/10",
+    # Address registry.
+    "192.0.0.0/24",
+    # Link-local networks.
+    "169.254.0.0/16",
+    # Testing networks.
+    "198.18.0.0/15",
+    "192.0.2.0/24",
+    "198.51.100.0/24",
+    "203.0.113.0/24",
+    # Multicast.
+    "224.0.0.0/4",
+    # Localhost
+    "::1/128",
+    # Link-local addresses.
+    "fe80::/10",
+    # Unique local addresses.
+    "fc00::/7",
+]
+
 DEFAULT_ROOM_VERSION = "6"
 
 ROOM_COMPLEXITY_TOO_GREAT = (
@@ -256,6 +285,38 @@ class ServerConfig(Config):
         # due to resource constraints
         self.admin_contact = config.get("admin_contact", None)
 
+        ip_range_blacklist = config.get(
+            "ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST
+        )
+
+        # Attempt to create an IPSet from the given ranges
+        try:
+            self.ip_range_blacklist = IPSet(ip_range_blacklist)
+        except Exception as e:
+            raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e
+        # Always blacklist 0.0.0.0, ::
+        self.ip_range_blacklist.update(["0.0.0.0", "::"])
+
+        try:
+            self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ()))
+        except Exception as e:
+            raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e
+
+        # The federation_ip_range_blacklist is used for backwards-compatibility
+        # and only applies to federation and identity servers. If it is not given,
+        # default to ip_range_blacklist.
+        federation_ip_range_blacklist = config.get(
+            "federation_ip_range_blacklist", ip_range_blacklist
+        )
+        try:
+            self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
+        except Exception as e:
+            raise ConfigError(
+                "Invalid range(s) provided in federation_ip_range_blacklist."
+            ) from e
+        # Always blacklist 0.0.0.0, ::
+        self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+
         if self.public_baseurl is not None:
             if self.public_baseurl[-1] != "/":
                 self.public_baseurl += "/"
@@ -561,6 +622,10 @@ class ServerConfig(Config):
     def generate_config_section(
         self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
     ):
+        ip_range_blacklist = "\n".join(
+            "        #  - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
+        )
+
         _, bind_port = parse_and_validate_server_name(server_name)
         if bind_port is not None:
             unsecure_port = bind_port - 400
@@ -675,11 +740,12 @@ class ServerConfig(Config):
         #
         #web_client_location: https://riot.example.com/
 
-        # The public-facing base URL that clients use to access this HS
-        # (not including _matrix/...). This is the same URL a user would
-        # enter into the 'custom HS URL' field on their client. If you
-        # use synapse with a reverse proxy, this should be the URL to reach
-        # synapse via the proxy.
+        # The public-facing base URL that clients use to access this Homeserver (not
+        # including _matrix/...). This is the same URL a user might enter into the
+        # 'Custom Homeserver URL' field on their client. If you use Synapse with a
+        # reverse proxy, this should be the URL to reach Synapse via the proxy.
+        # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see
+        # 'listeners' below).
         #
         #public_baseurl: https://example.com/
 
@@ -752,6 +818,33 @@ class ServerConfig(Config):
         #
         #enable_search: false
 
+        # Prevent outgoing requests from being sent to the following blacklisted IP address
+        # CIDR ranges. If this option is not specified then it defaults to private IP
+        # address ranges (see the example below).
+        #
+        # The blacklist applies to the outbound requests for federation, identity servers,
+        # push servers, and for checking key validity for third-party invite events.
+        #
+        # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+        # listed here, since they correspond to unroutable addresses.)
+        #
+        # This option replaces federation_ip_range_blacklist in Synapse v1.25.0.
+        #
+        #ip_range_blacklist:
+%(ip_range_blacklist)s
+
+        # List of IP address CIDR ranges that should be allowed for federation,
+        # identity servers, push servers, and for checking key validity for
+        # third-party invite events. This is useful for specifying exceptions to
+        # wide-ranging blacklisted target IP ranges - e.g. for communication with
+        # a push server only visible in your network.
+        #
+        # This whitelist overrides ip_range_blacklist and defaults to an empty
+        # list.
+        #
+        #ip_range_whitelist:
+        #   - '192.168.1.1'
+
         # List of ports that Synapse should listen on, their purpose and their
         # configuration.
         #
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 3d067d29db..3d05abc158 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -33,13 +33,14 @@ class SpamCheckerConfig(Config):
             # spam checker, and thus was simply a dictionary with module
             # and config keys. Support this old behaviour by checking
             # to see if the option resolves to a dictionary
-            self.spam_checkers.append(load_module(spam_checkers))
+            self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",)))
         elif isinstance(spam_checkers, list):
-            for spam_checker in spam_checkers:
+            for i, spam_checker in enumerate(spam_checkers):
+                config_path = ("spam_checker", "<item %i>" % i)
                 if not isinstance(spam_checker, dict):
-                    raise ConfigError("spam_checker syntax is incorrect")
+                    raise ConfigError("expected a mapping", config_path)
 
-                self.spam_checkers.append(load_module(spam_checker))
+                self.spam_checkers.append(load_module(spam_checker, config_path))
         else:
             raise ConfigError("spam_checker syntax is incorrect")
 
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 4427676167..19bdfd462b 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -27,24 +27,28 @@ class SSOConfig(Config):
         sso_config = config.get("sso") or {}  # type: Dict[str, Any]
 
         # The sso-specific template_dir
-        template_dir = sso_config.get("template_dir")
+        self.sso_template_dir = sso_config.get("template_dir")
 
         # Read templates from disk
         (
+            self.sso_login_idp_picker_template,
             self.sso_redirect_confirm_template,
             self.sso_auth_confirm_template,
             self.sso_error_template,
             sso_account_deactivated_template,
             sso_auth_success_template,
+            self.sso_auth_bad_user_template,
         ) = self.read_templates(
             [
+                "sso_login_idp_picker.html",
                 "sso_redirect_confirm.html",
                 "sso_auth_confirm.html",
                 "sso_error.html",
                 "sso_account_deactivated.html",
                 "sso_auth_success.html",
+                "sso_auth_bad_user.html",
             ],
-            template_dir,
+            self.sso_template_dir,
         )
 
         # These templates have no placeholders, so render them here
@@ -93,41 +97,135 @@ class SSOConfig(Config):
             #  - https://my.custom.client/
 
             # Directory in which Synapse will try to find the template files below.
-            # If not set, default templates from within the Synapse package will be used.
-            #
-            # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
-            # If you *do* uncomment it, you will need to make sure that all the templates
-            # below are in the directory.
+            # If not set, or the files named below are not found within the template
+            # directory, default templates from within the Synapse package will be used.
             #
             # Synapse will look for the following templates in this directory:
             #
+            # * HTML page to prompt the user to choose an Identity Provider during
+            #   login: 'sso_login_idp_picker.html'.
+            #
+            #   This is only used if multiple SSO Identity Providers are configured.
+            #
+            #   When rendering, this template is given the following variables:
+            #     * redirect_url: the URL that the user will be redirected to after
+            #       login.
+            #
+            #     * server_name: the homeserver's name.
+            #
+            #     * providers: a list of available Identity Providers. Each element is
+            #       an object with the following attributes:
+            #
+            #         * idp_id: unique identifier for the IdP
+            #         * idp_name: user-facing name for the IdP
+            #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+            #              for the IdP
+            #         * idp_brand: if specified in the IdP config, a textual identifier
+            #              for the brand of the IdP
+            #
+            #   The rendered HTML page should contain a form which submits its results
+            #   back as a GET request, with the following query parameters:
+            #
+            #     * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+            #       to the template)
+            #
+            #     * idp: the 'idp_id' of the chosen IDP.
+            #
+            # * HTML page to prompt new users to enter a userid and confirm other
+            #   details: 'sso_auth_account_details.html'. This is only shown if the
+            #   SSO implementation (with any user_mapping_provider) does not return
+            #   a localpart.
+            #
+            #   When rendering, this template is given the following variables:
+            #
+            #     * server_name: the homeserver's name.
+            #
+            #     * idp: details of the SSO Identity Provider that the user logged in
+            #       with: an object with the following attributes:
+            #
+            #         * idp_id: unique identifier for the IdP
+            #         * idp_name: user-facing name for the IdP
+            #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+            #              for the IdP
+            #         * idp_brand: if specified in the IdP config, a textual identifier
+            #              for the brand of the IdP
+            #
+            #     * user_attributes: an object containing details about the user that
+            #       we received from the IdP. May have the following attributes:
+            #
+            #         * display_name: the user's display_name
+            #         * emails: a list of email addresses
+            #
+            #   The template should render a form which submits the following fields:
+            #
+            #     * username: the localpart of the user's chosen user id
+            #
+            # * HTML page allowing the user to consent to the server's terms and
+            #   conditions. This is only shown for new users, and only if
+            #   `user_consent.require_at_registration` is set.
+            #
+            #   When rendering, this template is given the following variables:
+            #
+            #     * server_name: the homeserver's name.
+            #
+            #     * user_id: the user's matrix proposed ID.
+            #
+            #     * user_profile.display_name: the user's proposed display name, if any.
+            #
+            #     * consent_version: the version of the terms that the user will be
+            #       shown
+            #
+            #     * terms_url: a link to the page showing the terms.
+            #
+            #   The template should render a form which submits the following fields:
+            #
+            #     * accepted_version: the version of the terms accepted by the user
+            #       (ie, 'consent_version' from the input variables).
+            #
             # * HTML page for a confirmation step before redirecting back to the client
             #   with the login token: 'sso_redirect_confirm.html'.
             #
-            #   When rendering, this template is given three variables:
-            #     * redirect_url: the URL the user is about to be redirected to. Needs
-            #                     manual escaping (see
-            #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #   When rendering, this template is given the following variables:
+            #
+            #     * redirect_url: the URL the user is about to be redirected to.
             #
             #     * display_url: the same as `redirect_url`, but with the query
             #                    parameters stripped. The intention is to have a
             #                    human-readable URL to show to users, not to use it as
-            #                    the final address to redirect to. Needs manual escaping
-            #                    (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #                    the final address to redirect to.
             #
             #     * server_name: the homeserver's name.
             #
+            #     * new_user: a boolean indicating whether this is the user's first time
+            #          logging in.
+            #
+            #     * user_id: the user's matrix ID.
+            #
+            #     * user_profile.avatar_url: an MXC URI for the user's avatar, if any.
+            #           None if the user has not set an avatar.
+            #
+            #     * user_profile.display_name: the user's display name. None if the user
+            #           has not set a display name.
+            #
             # * HTML page which notifies the user that they are authenticating to confirm
             #   an operation on their account during the user interactive authentication
             #   process: 'sso_auth_confirm.html'.
             #
             #   When rendering, this template is given the following variables:
-            #     * redirect_url: the URL the user is about to be redirected to. Needs
-            #                     manual escaping (see
-            #                     https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #     * redirect_url: the URL the user is about to be redirected to.
             #
             #     * description: the operation which the user is being asked to confirm
             #
+            #     * idp: details of the Identity Provider that we will use to confirm
+            #       the user's identity: an object with the following attributes:
+            #
+            #         * idp_id: unique identifier for the IdP
+            #         * idp_name: user-facing name for the IdP
+            #         * idp_icon: if specified in the IdP config, an MXC URI for an icon
+            #              for the IdP
+            #         * idp_brand: if specified in the IdP config, a textual identifier
+            #              for the brand of the IdP
+            #
             # * HTML page shown after a successful user interactive authentication session:
             #   'sso_auth_success.html'.
             #
@@ -136,6 +234,14 @@ class SSOConfig(Config):
             #
             #   This template has no additional variables.
             #
+            # * HTML page shown after a user-interactive authentication session which
+            #   does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+            #
+            #   When rendering, this template is given the following variables:
+            #     * server_name: the homeserver's name.
+            #     * user_id_to_verify: the MXID of the user that we are trying to
+            #       validate.
+            #
             # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
             #   attempts to login: 'sso_account_deactivated.html'.
             #
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index 10a99c792e..c04e1c4e07 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config):
 
         provider = config.get("third_party_event_rules", None)
         if provider is not None:
-            self.third_party_event_rules = load_module(provider)
+            self.third_party_event_rules = load_module(
+                provider, ("third_party_event_rules",)
+            )
 
     def generate_config_section(self, **kwargs):
         return """\
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 57ab097eba..f10e33f7b8 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -53,6 +53,15 @@ class WriterLocations:
         default=["master"], type=List[str], converter=_instance_to_list_converter
     )
     typing = attr.ib(default="master", type=str)
+    to_device = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
+    account_data = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
+    receipts = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
 
 
 class WorkerConfig(Config):
@@ -85,6 +94,9 @@ class WorkerConfig(Config):
         # The port on the main synapse for HTTP replication endpoint
         self.worker_replication_http_port = config.get("worker_replication_http_port")
 
+        # The shared secret used for authentication when connecting to the main synapse.
+        self.worker_replication_secret = config.get("worker_replication_secret", None)
+
         self.worker_name = config.get("worker_name", self.worker_app)
 
         self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@@ -121,7 +133,7 @@ class WorkerConfig(Config):
 
         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing"):
+        for stream in ("events", "typing", "to_device", "account_data", "receipts"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -130,6 +142,21 @@ class WorkerConfig(Config):
                         % (instance, stream)
                     )
 
+        if len(self.writers.to_device) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `to_device` messages."
+            )
+
+        if len(self.writers.account_data) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `account_data` messages."
+            )
+
+        if len(self.writers.receipts) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `receipts` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
 
         # Whether this worker should run background tasks or not.
@@ -185,6 +212,13 @@ class WorkerConfig(Config):
         # data). If not provided this defaults to the main process.
         #
         #run_background_tasks_on: worker1
+
+        # A shared secret used by the replication APIs to authenticate HTTP requests
+        # from workers.
+        #
+        # By default this is unused and traffic is not authenticated.
+        #
+        #worker_replication_secret: ""
         """
 
     def read_arguments(self, args):
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 57fd426e87..14b21796d9 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -125,19 +125,24 @@ class FederationPolicyForHTTPS:
         self._no_verify_ssl_context = _no_verify_ssl.getContext()
         self._no_verify_ssl_context.set_info_callback(_context_info_cb)
 
-    def get_options(self, host: bytes):
+        self._should_verify = self._config.federation_verify_certificates
+
+        self._federation_certificate_verification_whitelist = (
+            self._config.federation_certificate_verification_whitelist
+        )
 
+    def get_options(self, host: bytes):
         # IPolicyForHTTPS.get_options takes bytes, but we want to compare
         # against the str whitelist. The hostnames in the whitelist are already
         # IDNA-encoded like the hosts will be here.
         ascii_host = host.decode("ascii")
 
         # Check if certificate verification has been enabled
-        should_verify = self._config.federation_verify_certificates
+        should_verify = self._should_verify
 
         # Check if we've disabled certificate verification for this host
-        if should_verify:
-            for regex in self._config.federation_certificate_verification_whitelist:
+        if self._should_verify:
+            for regex in self._federation_certificate_verification_whitelist:
                 if regex.match(ascii_host):
                     should_verify = False
                     break
@@ -227,7 +232,7 @@ class ConnectionVerifier:
 
     # This code is based on twisted.internet.ssl.ClientTLSOptions.
 
-    def __init__(self, hostname: bytes, verify_certs):
+    def __init__(self, hostname: bytes, verify_certs: bool):
         self._verify_certs = verify_certs
 
         _decoded = hostname.decode("ascii")
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 0422c43fab..8fb116ae18 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -18,7 +18,7 @@
 import collections.abc
 import hashlib
 import logging
-from typing import Dict
+from typing import Any, Callable, Dict, Tuple
 
 from canonicaljson import encode_canonical_json
 from signedjson.sign import sign_json
@@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
+from synapse.events import EventBase
 from synapse.events.utils import prune_event, prune_event_dict
 from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
+Hasher = Callable[[bytes], "hashlib._Hash"]
 
-def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
+
+def check_event_content_hash(
+    event: EventBase, hash_algorithm: Hasher = hashlib.sha256
+) -> bool:
     """Check whether the hash for this PDU matches the contents"""
     name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
     logger.debug(
@@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
     return message_hash_bytes == expected_hash
 
 
-def compute_content_hash(event_dict, hash_algorithm):
+def compute_content_hash(
+    event_dict: Dict[str, Any], hash_algorithm: Hasher
+) -> Tuple[str, bytes]:
     """Compute the content hash of an event, which is the hash of the
     unredacted event.
 
     Args:
-        event_dict (dict): The unredacted event as a dict
+        event_dict: The unredacted event as a dict
         hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
             to hash the event
 
     Returns:
-        tuple[str, bytes]: A tuple of the name of hash and the hash as raw
-        bytes.
+        A tuple of the name of hash and the hash as raw bytes.
     """
     event_dict = dict(event_dict)
     event_dict.pop("age_ts", None)
@@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm):
     return hashed.name, hashed.digest()
 
 
-def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+def compute_event_reference_hash(
+    event, hash_algorithm: Hasher = hashlib.sha256
+) -> Tuple[str, bytes]:
     """Computes the event reference hash. This is the hash of the redacted
     event.
 
     Args:
-        event (FrozenEvent)
+        event
         hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
             to hash the event
 
     Returns:
-        tuple[str, bytes]: A tuple of the name of hash and the hash as raw
-        bytes.
+        A tuple of the name of hash and the hash as raw bytes.
     """
     tmp_event = prune_event(event)
     event_dict = tmp_event.get_pdu_json()
@@ -156,7 +163,7 @@ def add_hashes_and_signatures(
     event_dict: JsonDict,
     signature_name: str,
     signing_key: SigningKey,
-):
+) -> None:
     """Add content hash and sign the event
 
     Args:
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c04ad77cf9..902128a23c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -14,9 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import logging
 import urllib
 from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from signedjson.key import (
@@ -40,6 +42,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.config.key import TrustedKeyServer
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
@@ -47,11 +50,15 @@ from synapse.logging.context import (
     run_in_background,
 )
 from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.metrics import Measure
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -61,16 +68,17 @@ class VerifyJsonRequest:
     A request to verify a JSON object.
 
     Attributes:
-        server_name(str): The name of the server to verify against.
-
-        key_ids(set[str]): The set of key_ids to that could be used to verify the
-            JSON object
+        server_name: The name of the server to verify against.
 
-        json_object(dict): The JSON object to verify.
+        json_object: The JSON object to verify.
 
-        minimum_valid_until_ts (int): time at which we require the signing key to
+        minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
 
+        request_name: The name of the request.
+
+        key_ids: The set of key_ids to that could be used to verify the JSON object
+
         key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
             A deferred (server_name, key_id, verify_key) tuple that resolves when
             a verify key has been fetched. The deferreds' callbacks are run with no
@@ -80,12 +88,12 @@ class VerifyJsonRequest:
             errbacks with an M_UNAUTHORIZED SynapseError.
     """
 
-    server_name = attr.ib()
-    json_object = attr.ib()
-    minimum_valid_until_ts = attr.ib()
-    request_name = attr.ib()
-    key_ids = attr.ib(init=False)
-    key_ready = attr.ib(default=attr.Factory(defer.Deferred))
+    server_name = attr.ib(type=str)
+    json_object = attr.ib(type=JsonDict)
+    minimum_valid_until_ts = attr.ib(type=int)
+    request_name = attr.ib(type=str)
+    key_ids = attr.ib(init=False, type=List[str])
+    key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
 
     def __attrs_post_init__(self):
         self.key_ids = signature_ids(self.json_object, self.server_name)
@@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
 
 
 class Keyring:
-    def __init__(self, hs, key_fetchers=None):
+    def __init__(
+        self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
+    ):
         self.clock = hs.get_clock()
 
         if key_fetchers is None:
@@ -112,22 +122,26 @@ class Keyring:
         # completes.
         #
         # These are regular, logcontext-agnostic Deferreds.
-        self.key_downloads = {}
+        self.key_downloads = {}  # type: Dict[str, defer.Deferred]
 
     def verify_json_for_server(
-        self, server_name, json_object, validity_time, request_name
-    ):
+        self,
+        server_name: str,
+        json_object: JsonDict,
+        validity_time: int,
+        request_name: str,
+    ) -> defer.Deferred:
         """Verify that a JSON object has been signed by a given server
 
         Args:
-            server_name (str): name of the server which must have signed this object
+            server_name: name of the server which must have signed this object
 
-            json_object (dict): object to be checked
+            json_object: object to be checked
 
-            validity_time (int): timestamp at which we require the signing key to
+            validity_time: timestamp at which we require the signing key to
                 be valid. (0 implies we don't care)
 
-            request_name (str): an identifier for this json object (eg, an event id)
+            request_name: an identifier for this json object (eg, an event id)
                 for logging.
 
         Returns:
@@ -138,12 +152,14 @@ class Keyring:
         requests = (req,)
         return make_deferred_yieldable(self._verify_objects(requests)[0])
 
-    def verify_json_objects_for_server(self, server_and_json):
+    def verify_json_objects_for_server(
+        self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+    ) -> List[defer.Deferred]:
         """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
-            server_and_json (iterable[Tuple[str, dict, int, str]):
+            server_and_json:
                 Iterable of (server_name, json_object, validity_time, request_name)
                 tuples.
 
@@ -164,13 +180,14 @@ class Keyring:
             for server_name, json_object, validity_time, request_name in server_and_json
         )
 
-    def _verify_objects(self, verify_requests):
+    def _verify_objects(
+        self, verify_requests: Iterable[VerifyJsonRequest]
+    ) -> List[defer.Deferred]:
         """Does the work of verify_json_[objects_]for_server
 
 
         Args:
-            verify_requests (iterable[VerifyJsonRequest]):
-                Iterable of verification requests.
+            verify_requests: Iterable of verification requests.
 
         Returns:
             List<Deferred[None]>: for each input item, a deferred indicating success
@@ -182,7 +199,7 @@ class Keyring:
         key_lookups = []
         handle = preserve_fn(_handle_key_deferred)
 
-        def process(verify_request):
+        def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
             """Process an entry in the request list
 
             Adds a key request to key_lookups, and returns a deferred which
@@ -222,18 +239,20 @@ class Keyring:
 
         return results
 
-    async def _start_key_lookups(self, verify_requests):
+    async def _start_key_lookups(
+        self, verify_requests: List[VerifyJsonRequest]
+    ) -> None:
         """Sets off the key fetches for each verify request
 
         Once each fetch completes, verify_request.key_ready will be resolved.
 
         Args:
-            verify_requests (List[VerifyJsonRequest]):
+            verify_requests:
         """
 
         try:
             # map from server name to a set of outstanding request ids
-            server_to_request_ids = {}
+            server_to_request_ids = {}  # type: Dict[str, Set[int]]
 
             for verify_request in verify_requests:
                 server_name = verify_request.server_name
@@ -275,11 +294,11 @@ class Keyring:
         except Exception:
             logger.exception("Error starting key lookups")
 
-    async def wait_for_previous_lookups(self, server_names) -> None:
+    async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
         """Waits for any previous key lookups for the given servers to finish.
 
         Args:
-            server_names (Iterable[str]): list of servers which we want to look up
+            server_names: list of servers which we want to look up
 
         Returns:
             Resolves once all key lookups for the given servers have
@@ -304,7 +323,7 @@ class Keyring:
 
             loop_count += 1
 
-    def _get_server_verify_keys(self, verify_requests):
+    def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
         """Tries to find at least one key for each verify request
 
         For each verify_request, verify_request.key_ready is called back with
@@ -312,7 +331,7 @@ class Keyring:
         with a SynapseError if none of the keys are found.
 
         Args:
-            verify_requests (list[VerifyJsonRequest]): list of verify requests
+            verify_requests: list of verify requests
         """
 
         remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@@ -366,17 +385,19 @@ class Keyring:
 
         run_in_background(do_iterations)
 
-    async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+    async def _attempt_key_fetches_with_fetcher(
+        self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
+    ):
         """Use a key fetcher to attempt to satisfy some key requests
 
         Args:
-            fetcher (KeyFetcher): fetcher to use to fetch the keys
-            remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
+            fetcher: fetcher to use to fetch the keys
+            remaining_requests: outstanding key requests.
                 Any successfully-completed requests will be removed from the list.
         """
-        # dict[str, dict[str, int]]: keys to fetch.
+        # The keys to fetch.
         # server_name -> key_id -> min_valid_ts
-        missing_keys = defaultdict(dict)
+        missing_keys = defaultdict(dict)  # type: Dict[str, Dict[str, int]]
 
         for verify_request in remaining_requests:
             # any completed requests should already have been removed
@@ -438,16 +459,18 @@ class Keyring:
         remaining_requests.difference_update(completed)
 
 
-class KeyFetcher:
-    async def get_keys(self, keys_to_fetch):
+class KeyFetcher(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
-            keys_to_fetch (dict[str, dict[str, int]]):
+            keys_to_fetch:
                 the keys to be fetched. server_name -> key_id -> min_valid_ts
 
         Returns:
-            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
-                map from server_name -> key_id -> FetchKeyResult
+            Map from server_name -> key_id -> FetchKeyResult
         """
         raise NotImplementedError
 
@@ -455,31 +478,35 @@ class KeyFetcher:
 class StoreKeyFetcher(KeyFetcher):
     """KeyFetcher impl which fetches keys from our data store"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    async def get_keys(self, keys_to_fetch):
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """see KeyFetcher.get_keys"""
 
-        keys_to_fetch = (
+        key_ids_to_fetch = (
             (server_name, key_id)
             for server_name, keys_for_server in keys_to_fetch.items()
             for key_id in keys_for_server.keys()
         )
 
-        res = await self.store.get_server_verify_keys(keys_to_fetch)
-        keys = {}
+        res = await self.store.get_server_verify_keys(key_ids_to_fetch)
+        keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
         for (server_name, key_id), key in res.items():
             keys.setdefault(server_name, {})[key_id] = key
         return keys
 
 
-class BaseV2KeyFetcher:
-    def __init__(self, hs):
+class BaseV2KeyFetcher(KeyFetcher):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.config = hs.get_config()
 
-    async def process_v2_response(self, from_server, response_json, time_added_ms):
+    async def process_v2_response(
+        self, from_server: str, response_json: JsonDict, time_added_ms: int
+    ) -> Dict[str, FetchKeyResult]:
         """Parse a 'Server Keys' structure from the result of a /key request
 
         This is used to parse either the entirety of the response from
@@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
         to /_matrix/key/v2/query.
 
         Args:
-            from_server (str): the name of the server producing this result: either
+            from_server: the name of the server producing this result: either
                 the origin server for a /_matrix/key/v2/server request, or the notary
                 for a /_matrix/key/v2/query.
 
-            response_json (dict): the json-decoded Server Keys response object
+            response_json: the json-decoded Server Keys response object
 
-            time_added_ms (int): the timestamp to record in server_keys_json
+            time_added_ms: the timestamp to record in server_keys_json
 
         Returns:
-            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+            Map from key_id to result object
         """
         ts_valid_until_ms = response_json["valid_until_ts"]
 
@@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
 class PerspectivesKeyFetcher(BaseV2KeyFetcher):
     """KeyFetcher impl which fetches keys from the "perspectives" servers"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.clock = hs.get_clock()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
         self.key_servers = self.config.key_servers
 
-    async def get_keys(self, keys_to_fetch):
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """see KeyFetcher.get_keys"""
 
-        async def get_key(key_server):
+        async def get_key(key_server: TrustedKeyServer) -> Dict:
             try:
-                result = await self.get_server_verify_key_v2_indirect(
+                return await self.get_server_verify_key_v2_indirect(
                     keys_to_fetch, key_server
                 )
-                return result
             except KeyLookupError as e:
                 logger.warning(
                     "Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             ).addErrback(unwrapFirstError)
         )
 
-        union_of_keys = {}
+        union_of_keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
         for result in results:
             for server_name, keys in result.items():
                 union_of_keys.setdefault(server_name, {}).update(keys)
 
         return union_of_keys
 
-    async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+    async def get_server_verify_key_v2_indirect(
+        self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
-            keys_to_fetch (dict[str, dict[str, int]]):
+            keys_to_fetch:
                 the keys to be fetched. server_name -> key_id -> min_valid_ts
 
-            key_server (synapse.config.key.TrustedKeyServer): notary server to query for
-                the keys
+            key_server: notary server to query for the keys
 
         Returns:
-            dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
-                from server_name -> key_id -> FetchKeyResult
+            Map from server_name -> key_id -> FetchKeyResult
 
         Raises:
             KeyLookupError if there was an error processing the entire response from
@@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         except HttpResponseException as e:
             raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
-        keys = {}
-        added_keys = []
+        keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
+        added_keys = []  # type: List[Tuple[str, str, FetchKeyResult]]
 
         time_now_ms = self.clock.time_msec()
 
+        assert isinstance(query_response, dict)
         for response in query_response["server_keys"]:
             # do this first, so that we can give useful errors thereafter
             server_name = response.get("server_name")
@@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
         return keys
 
-    def _validate_perspectives_response(self, key_server, response):
+    def _validate_perspectives_response(
+        self, key_server: TrustedKeyServer, response: JsonDict
+    ) -> None:
         """Optionally check the signature on the result of a /key/query request
 
         Args:
-            key_server (synapse.config.key.TrustedKeyServer): the notary server that
-                produced this result
+            key_server: the notary server that produced this result
 
-            response (dict): the json-decoded Server Keys response object
+            response: the json-decoded Server Keys response object
         """
         perspective_name = key_server.server_name
         perspective_keys = key_server.verify_keys
@@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 class ServerKeyFetcher(BaseV2KeyFetcher):
     """KeyFetcher impl which fetches keys from the origin servers"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.clock = hs.get_clock()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
 
-    async def get_keys(self, keys_to_fetch):
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
-            keys_to_fetch (dict[str, iterable[str]]):
+            keys_to_fetch:
                 the keys to be fetched. server_name -> key_ids
 
         Returns:
-            dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
-                map from server_name -> key_id -> FetchKeyResult
+            Map from server_name -> key_id -> FetchKeyResult
         """
 
         results = {}
 
-        async def get_key(key_to_fetch_item):
+        async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
             server_name, key_ids = key_to_fetch_item
             try:
                 keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
@@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         await yieldable_gather_results(get_key, keys_to_fetch.items())
         return results
 
-    async def get_server_verify_key_v2_direct(self, server_name, key_ids):
+    async def get_server_verify_key_v2_direct(
+        self, server_name: str, key_ids: Iterable[str]
+    ) -> Dict[str, FetchKeyResult]:
         """
 
         Args:
-            server_name (str):
-            key_ids (iterable[str]):
+            server_name:
+            key_ids:
 
         Returns:
-            dict[str, FetchKeyResult]: map from key ID to lookup result
+            Map from key ID to lookup result
 
         Raises:
             KeyLookupError if there was a problem making the lookup
         """
-        keys = {}  # type: dict[str, FetchKeyResult]
+        keys = {}  # type: Dict[str, FetchKeyResult]
 
         for requested_key_id in key_ids:
             # we may have found this key as a side-effect of asking for another.
@@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             except HttpResponseException as e:
                 raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
+            assert isinstance(response, dict)
             if response["server_name"] != server_name:
                 raise KeyLookupError(
                     "Expected a response for server %r not %r"
@@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         return keys
 
 
-async def _handle_key_deferred(verify_request) -> None:
+async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
     """Waits for the key to become available, and then performs a verification
 
     Args:
-        verify_request (VerifyJsonRequest):
+        verify_request:
 
     Raises:
         SynapseError if there was a problem performing the verification
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8028663fa8..3ec4120f85 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -17,7 +17,6 @@
 
 import abc
 import os
-from distutils.util import strtobool
 from typing import Dict, Optional, Tuple, Type
 
 from unpaddedbase64 import encode_base64
@@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
+from synapse.util.stringutils import strtobool
 
 # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
 # bugs where we accidentally share e.g. signature dicts. However, converting a
@@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze
 # NOTE: This is overridden by the configuration by the Synapse worker apps, but
 # for the sake of tests, it is set here while it cannot be configured on the
 # homeserver object itself.
+
 USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
 
 
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 936896656a..e7e3a7b9a4 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,10 +15,11 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import Collection
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     import synapse.events
@@ -39,7 +40,9 @@ class SpamChecker:
             else:
                 self.spam_checkers.append(module(config=config))
 
-    def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
+    async def check_event_for_spam(
+        self, event: "synapse.events.EventBase"
+    ) -> Union[bool, str]:
         """Checks if a given event is considered "spammy" by this server.
 
         If the server considers an event spammy, then it will be rejected if
@@ -50,15 +53,16 @@ class SpamChecker:
             event: the event to be checked
 
         Returns:
-            True if the event is spammy.
+            True or a string if the event is spammy. If a string is returned it
+            will be used as the error message returned to the user.
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.check_event_for_spam(event):
+            if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
                 return True
 
         return False
 
-    def user_may_invite(
+    async def user_may_invite(
         self, inviter_userid: str, invitee_userid: str, room_id: str
     ) -> bool:
         """Checks if a given user may send an invite
@@ -75,14 +79,18 @@ class SpamChecker:
         """
         for spam_checker in self.spam_checkers:
             if (
-                spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+                await maybe_awaitable(
+                    spam_checker.user_may_invite(
+                        inviter_userid, invitee_userid, room_id
+                    )
+                )
                 is False
             ):
                 return False
 
         return True
 
-    def user_may_create_room(self, userid: str) -> bool:
+    async def user_may_create_room(self, userid: str) -> bool:
         """Checks if a given user may create a room
 
         If this method returns false, the creation request will be rejected.
@@ -94,12 +102,15 @@ class SpamChecker:
             True if the user may create a room, otherwise False
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.user_may_create_room(userid) is False:
+            if (
+                await maybe_awaitable(spam_checker.user_may_create_room(userid))
+                is False
+            ):
                 return False
 
         return True
 
-    def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
+    async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
         """Checks if a given user may create a room alias
 
         If this method returns false, the association request will be rejected.
@@ -112,12 +123,17 @@ class SpamChecker:
             True if the user may create a room alias, otherwise False
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
+            if (
+                await maybe_awaitable(
+                    spam_checker.user_may_create_room_alias(userid, room_alias)
+                )
+                is False
+            ):
                 return False
 
         return True
 
-    def user_may_publish_room(self, userid: str, room_id: str) -> bool:
+    async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
         """Checks if a given user may publish a room to the directory
 
         If this method returns false, the publish request will be rejected.
@@ -130,12 +146,17 @@ class SpamChecker:
             True if the user may publish the room, otherwise False
         """
         for spam_checker in self.spam_checkers:
-            if spam_checker.user_may_publish_room(userid, room_id) is False:
+            if (
+                await maybe_awaitable(
+                    spam_checker.user_may_publish_room(userid, room_id)
+                )
+                is False
+            ):
                 return False
 
         return True
 
-    def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+    async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
         """Checks if a user ID or display name are considered "spammy" by this server.
 
         If the server considers a username spammy, then it will not be included in
@@ -157,12 +178,12 @@ class SpamChecker:
             if checker:
                 # Make a copy of the user profile object to ensure the spam checker
                 # cannot modify it.
-                if checker(user_profile.copy()):
+                if await maybe_awaitable(checker(user_profile.copy())):
                     return True
 
         return False
 
-    def check_registration_for_spam(
+    async def check_registration_for_spam(
         self,
         email_threepid: Optional[dict],
         username: Optional[str],
@@ -185,7 +206,9 @@ class SpamChecker:
             # spam checker
             checker = getattr(spam_checker, "check_registration_for_spam", None)
             if checker:
-                behaviour = checker(email_threepid, username, request_info)
+                behaviour = await maybe_awaitable(
+                    checker(email_threepid, username, request_info)
+                )
                 assert isinstance(behaviour, RegistrationBehaviour)
                 if behaviour != RegistrationBehaviour.ALLOW:
                     return behaviour
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 14f7f1156f..9c22e33813 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -79,13 +79,15 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
         "state_key",
         "depth",
         "prev_events",
-        "prev_state",
         "auth_events",
         "origin",
         "origin_server_ts",
-        "membership",
     ]
 
+    # Room versions from before MSC2176 had additional allowed keys.
+    if not room_version.msc2176_redaction_rules:
+        allowed_keys.extend(["prev_state", "membership"])
+
     event_type = event_dict["type"]
 
     new_content = {}
@@ -98,6 +100,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     if event_type == EventTypes.Member:
         add_fields("membership")
     elif event_type == EventTypes.Create:
+        # MSC2176 rules state that create events cannot be redacted.
+        if room_version.msc2176_redaction_rules:
+            return event_dict
+
         add_fields("creator")
     elif event_type == EventTypes.JoinRules:
         add_fields("join_rule")
@@ -112,10 +118,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
             "kick",
             "redact",
         )
+
+        if room_version.msc2176_redaction_rules:
+            add_fields("invite")
+
     elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
         add_fields("aliases")
     elif event_type == EventTypes.RoomHistoryVisibility:
         add_fields("history_visibility")
+    elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
+        add_fields("redacts")
 
     allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
 
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 38aa47963f..383737520a 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -78,6 +78,7 @@ class FederationBase:
 
         ctx = current_context()
 
+        @defer.inlineCallbacks
         def callback(_, pdu: EventBase):
             with PreserveLoggingContext(ctx):
                 if not check_event_content_hash(pdu):
@@ -105,7 +106,11 @@ class FederationBase:
                         )
                     return redacted_event
 
-                if self.spam_checker.check_event_for_spam(pdu):
+                result = yield defer.ensureDeferred(
+                    self.spam_checker.check_event_for_spam(pdu)
+                )
+
+                if result:
                     logger.warning(
                         "Event contains spam, redacting %s: %s",
                         pdu.event_id,
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 302b2f69bc..40e1451201 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,6 +18,7 @@ import copy
 import itertools
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -26,7 +27,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Sequence,
     Tuple,
     TypeVar,
     Union,
@@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
 
 
 class FederationClient(FederationBase):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.pdu_destination_tried = {}
+        self.pdu_destination_tried = {}  # type: Dict[str, Dict[str, int]]
         self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
         self.state = hs.get_state_handler()
         self.transport_layer = hs.get_federation_transport_client()
@@ -116,33 +119,32 @@ class FederationClient(FederationBase):
                 self.pdu_destination_tried[event_id] = destination_dict
 
     @log_function
-    def make_query(
+    async def make_query(
         self,
-        destination,
-        query_type,
-        args,
-        retry_on_dns_fail=False,
-        ignore_backoff=False,
-    ):
+        destination: str,
+        query_type: str,
+        args: dict,
+        retry_on_dns_fail: bool = False,
+        ignore_backoff: bool = False,
+    ) -> JsonDict:
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            query_type (str): Category of the query type; should match the
+            destination: Domain name of the remote homeserver
+            query_type: Category of the query type; should match the
                 handler name used in register_query_handler().
-            args (dict): Mapping of strings to strings containing the details
+            args: Mapping of strings to strings containing the details
                 of the query request.
-            ignore_backoff (bool): true to ignore the historical backoff data
+            ignore_backoff: true to ignore the historical backoff data
                 and try the request anyway.
 
         Returns:
-            a Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels(query_type).inc()
 
-        return self.transport_layer.make_query(
+        return await self.transport_layer.make_query(
             destination,
             query_type,
             args,
@@ -151,42 +153,52 @@ class FederationClient(FederationBase):
         )
 
     @log_function
-    def query_client_keys(self, destination, content, timeout):
+    async def query_client_keys(
+        self, destination: str, content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Query device keys for a device hosted on a remote server.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            content (dict): The query content.
+            destination: Domain name of the remote homeserver
+            content: The query content.
 
         Returns:
-            an Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels("client_device_keys").inc()
-        return self.transport_layer.query_client_keys(destination, content, timeout)
+        return await self.transport_layer.query_client_keys(
+            destination, content, timeout
+        )
 
     @log_function
-    def query_user_devices(self, destination, user_id, timeout=30000):
+    async def query_user_devices(
+        self, destination: str, user_id: str, timeout: int = 30000
+    ) -> JsonDict:
         """Query the device keys for a list of user ids hosted on a remote
         server.
         """
         sent_queries_counter.labels("user_devices").inc()
-        return self.transport_layer.query_user_devices(destination, user_id, timeout)
+        return await self.transport_layer.query_user_devices(
+            destination, user_id, timeout
+        )
 
     @log_function
-    def claim_client_keys(self, destination, content, timeout):
+    async def claim_client_keys(
+        self, destination: str, content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            content (dict): The query content.
+            destination: Domain name of the remote homeserver
+            content: The query content.
 
         Returns:
-            an Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
-        return self.transport_layer.claim_client_keys(destination, content, timeout)
+        return await self.transport_layer.claim_client_keys(
+            destination, content, timeout
+        )
 
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@@ -195,10 +207,10 @@ class FederationClient(FederationBase):
         given destination server.
 
         Args:
-            dest (str): The remote homeserver to ask.
-            room_id (str): The room_id to backfill.
-            limit (int): The maximum number of events to return.
-            extremities (list): our current backwards extremities, to backfill from
+            dest: The remote homeserver to ask.
+            room_id: The room_id to backfill.
+            limit: The maximum number of events to return.
+            extremities: our current backwards extremities, to backfill from
         """
         logger.debug("backfill extrem=%s", extremities)
 
@@ -370,7 +382,7 @@ class FederationClient(FederationBase):
                 for events that have failed their checks
 
         Returns:
-            Deferred : A list of PDUs that have valid signatures and hashes.
+            A list of PDUs that have valid signatures and hashes.
         """
         deferreds = self._check_sigs_and_hashes(room_version, pdus)
 
@@ -418,7 +430,9 @@ class FederationClient(FederationBase):
         else:
             return [p for p in valid_pdus if p]
 
-    async def get_event_auth(self, destination, room_id, event_id):
+    async def get_event_auth(
+        self, destination: str, room_id: str, event_id: str
+    ) -> List[EventBase]:
         res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
 
         room_version = await self.store.get_room_version(room_id)
@@ -700,18 +714,16 @@ class FederationClient(FederationBase):
 
         return await self._try_destination_list("send_join", destinations, send_request)
 
-    async def _do_send_join(self, destination: str, pdu: EventBase):
+    async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_join_v2(
+            return await self.transport_layer.send_join_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
-
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -769,7 +781,7 @@ class FederationClient(FederationBase):
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_invite_v2(
+            return await self.transport_layer.send_invite_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
@@ -779,7 +791,6 @@ class FederationClient(FederationBase):
                     "invite_room_state": pdu.unsigned.get("invite_room_state", []),
                 },
             )
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -799,7 +810,7 @@ class FederationClient(FederationBase):
                         "User's homeserver does not support this room version",
                         Codes.UNSUPPORTED_ROOM_VERSION,
                     )
-            elif e.code == 403:
+            elif e.code in (403, 429):
                 raise e.to_synapse_error()
             else:
                 raise
@@ -842,18 +853,16 @@ class FederationClient(FederationBase):
             "send_leave", destinations, send_request
         )
 
-    async def _do_send_leave(self, destination, pdu):
+    async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_leave_v2(
+            return await self.transport_layer.send_leave_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
-
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -879,7 +888,7 @@ class FederationClient(FederationBase):
         # content.
         return resp[1]
 
-    def get_public_rooms(
+    async def get_public_rooms(
         self,
         remote_server: str,
         limit: Optional[int] = None,
@@ -887,7 +896,7 @@ class FederationClient(FederationBase):
         search_filter: Optional[Dict] = None,
         include_all_networks: bool = False,
         third_party_instance_id: Optional[str] = None,
-    ):
+    ) -> JsonDict:
         """Get the list of public rooms from a remote homeserver
 
         Args:
@@ -901,8 +910,7 @@ class FederationClient(FederationBase):
                 party instance
 
         Returns:
-            Awaitable[Dict[str, Any]]: The response from the remote server, or None if
-            `remote_server` is the same as the local server_name
+            The response from the remote server.
 
         Raises:
             HttpResponseException: There was an exception returned from the remote server
@@ -910,7 +918,7 @@ class FederationClient(FederationBase):
                 requests over federation
 
         """
-        return self.transport_layer.get_public_rooms(
+        return await self.transport_layer.get_public_rooms(
             remote_server,
             limit,
             since_token,
@@ -923,7 +931,7 @@ class FederationClient(FederationBase):
         self,
         destination: str,
         room_id: str,
-        earliest_events_ids: Sequence[str],
+        earliest_events_ids: Iterable[str],
         latest_events: Iterable[EventBase],
         limit: int,
         min_depth: int,
@@ -974,7 +982,9 @@ class FederationClient(FederationBase):
 
         return signed_events
 
-    async def forward_third_party_invite(self, destinations, room_id, event_dict):
+    async def forward_third_party_invite(
+        self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
+    ) -> None:
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -983,7 +993,7 @@ class FederationClient(FederationBase):
                 await self.transport_layer.exchange_third_party_invite(
                     destination=destination, room_id=room_id, event_dict=event_dict
                 )
-                return None
+                return
             except CodeMessageException:
                 raise
             except Exception as e:
@@ -995,7 +1005,7 @@ class FederationClient(FederationBase):
 
     async def get_room_complexity(
         self, destination: str, room_id: str
-    ) -> Optional[dict]:
+    ) -> Optional[JsonDict]:
         """
         Fetch the complexity of a remote room from another server.
 
@@ -1008,10 +1018,9 @@ class FederationClient(FederationBase):
             could not fetch the complexity.
         """
         try:
-            complexity = await self.transport_layer.get_room_complexity(
+            return await self.transport_layer.get_room_complexity(
                 destination=destination, room_id=room_id
             )
-            return complexity
         except CodeMessageException as e:
             # We didn't manage to get it -- probably a 404. We are okay if other
             # servers don't give it to us.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4b6ab470d0..171d25c945 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import random
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -48,7 +49,6 @@ from synapse.events import EventBase
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
-from synapse.http.endpoint import parse_server_name
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
     make_deferred_yieldable,
@@ -65,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_server_name
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -845,7 +846,6 @@ class FederationHandlerRegistry:
 
     def __init__(self, hs: "HomeServer"):
         self.config = hs.config
-        self.http_client = hs.get_simple_http_client()
         self.clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
 
@@ -861,8 +861,10 @@ class FederationHandlerRegistry:
         )  # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
         self.query_handlers = {}  # type: Dict[str, Callable[[dict], Awaitable[None]]]
 
-        # Map from type to instance name that we should route EDU handling to.
-        self._edu_type_to_instance = {}  # type: Dict[str, str]
+        # Map from type to instance names that we should route EDU handling to.
+        # We randomly choose one instance from the list to route to for each new
+        # EDU received.
+        self._edu_type_to_instance = {}  # type: Dict[str, List[str]]
 
     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
@@ -906,7 +908,12 @@ class FederationHandlerRegistry:
     def register_instance_for_edu(self, edu_type: str, instance_name: str):
         """Register that the EDU handler is on a different instance than master.
         """
-        self._edu_type_to_instance[edu_type] = instance_name
+        self._edu_type_to_instance[edu_type] = [instance_name]
+
+    def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
+        """Register that the EDU handler is on multiple instances.
+        """
+        self._edu_type_to_instance[edu_type] = instance_names
 
     async def on_edu(self, edu_type: str, origin: str, content: dict):
         if not self.config.use_presence and edu_type == "m.presence":
@@ -925,8 +932,11 @@ class FederationHandlerRegistry:
             return
 
         # Check if we can route it somewhere else that isn't us
-        route_to = self._edu_type_to_instance.get(edu_type, "master")
-        if route_to != self._instance_name:
+        instances = self._edu_type_to_instance.get(edu_type, ["master"])
+        if self._instance_name not in instances:
+            # Pick an instance randomly so that we don't overload one.
+            route_to = random.choice(instances)
+
             try:
                 await self._send_edu(
                     instance_name=route_to,
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..643b26ae6d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -142,6 +142,8 @@ class FederationSender:
             self._wake_destinations_needing_catchup,
         )
 
+        self._external_cache = hs.get_external_cache()
+
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
@@ -197,22 +199,40 @@ class FederationSender:
                     if not event.internal_metadata.should_proactively_send():
                         return
 
-                    try:
-                        # Get the state from before the event.
-                        # We need to make sure that this is the state from before
-                        # the event and not from after it.
-                        # Otherwise if the last member on a server in a room is
-                        # banned then it won't receive the event because it won't
-                        # be in the room after the ban.
-                        destinations = await self.state.get_hosts_in_room_at_events(
-                            event.room_id, event_ids=event.prev_event_ids()
-                        )
-                    except Exception:
-                        logger.exception(
-                            "Failed to calculate hosts in room for event: %s",
-                            event.event_id,
+                    destinations = None  # type: Optional[Set[str]]
+                    if not event.prev_event_ids():
+                        # If there are no prev event IDs then the state is empty
+                        # and so no remote servers in the room
+                        destinations = set()
+                    else:
+                        # We check the external cache for the destinations, which is
+                        # stored per state group.
+
+                        sg = await self._external_cache.get(
+                            "event_to_prev_state_group", event.event_id
                         )
-                        return
+                        if sg:
+                            destinations = await self._external_cache.get(
+                                "get_joined_hosts", str(sg)
+                            )
+
+                    if destinations is None:
+                        try:
+                            # Get the state from before the event.
+                            # We need to make sure that this is the state from before
+                            # the event and not from after it.
+                            # Otherwise if the last member on a server in a room is
+                            # banned then it won't receive the event because it won't
+                            # be in the room after the ban.
+                            destinations = await self.state.get_hosts_in_room_at_events(
+                                event.room_id, event_ids=event.prev_event_ids()
+                            )
+                        except Exception:
+                            logger.exception(
+                                "Failed to calculate hosts in room for event: %s",
+                                event.event_id,
+                            )
+                            return
 
                     destinations = {
                         d
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 17a10f622e..abe9168c78 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -35,7 +35,7 @@ class TransportLayerClient:
 
     def __init__(self, hs):
         self.server_name = hs.hostname
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
 
     @log_function
     def get_room_state_ids(self, destination, room_id, event_id):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b53e7a20ec..95c64510a9 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -28,7 +28,6 @@ from synapse.api.urls import (
     FEDERATION_V1_PREFIX,
     FEDERATION_V2_PREFIX,
 )
-from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
     parse_boolean_from_args,
@@ -45,6 +44,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.server import HomeServer
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
+from synapse.util.stringutils import parse_and_validate_server_name
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger(__name__)
@@ -144,7 +144,7 @@ class Authenticator:
         ):
             raise FederationDeniedError(origin)
 
-        if not json_request["signatures"]:
+        if origin is None or not json_request["signatures"]:
             raise NoAuthenticationError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED
             )
@@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
 
     Args:
         hs (synapse.server.HomeServer): homeserver
-        resource (TransportLayerServer): resource class to register to
+        resource (JsonResource): resource class to register to
         authenticator (Authenticator): authenticator to use
         ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
         servlet_groups (list[str], optional): List of servlet groups to register.
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb81c0e81d..d29b066a56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
 class BaseHandler:
     """
     Common base class for the event handlers.
+
+    Deprecated: new code should not use this. Instead, Handler classes should define the
+    fields they actually need. The utility methods should either be factored out to
+    standalone helper functions, or to different Handler classes.
     """
 
     def __init__(self, hs: "HomeServer"):
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 341135822e..b1a5df9638 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,14 +13,157 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 from typing import TYPE_CHECKING, List, Tuple
 
+from synapse.replication.http.account_data import (
+    ReplicationAddTagRestServlet,
+    ReplicationRemoveTagRestServlet,
+    ReplicationRoomAccountDataRestServlet,
+    ReplicationUserAccountDataRestServlet,
+)
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
 
+class AccountDataHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._store = hs.get_datastore()
+        self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
+
+        self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+        self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+        self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+        self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+        self._account_data_writers = hs.config.worker.writers.account_data
+
+    async def add_account_data_to_room(
+        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The maximum stream ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_to_room(
+                user_id, room_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+
+            return max_stream_id
+        else:
+            response = await self._room_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_account_data_for_user(
+        self, user_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The maximum stream ID.
+        """
+
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_for_user(
+                user_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._user_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_tag_to_room(
+        self, user_id: str, room_id: str, tag: str, content: JsonDict
+    ) -> int:
+        """Add a tag to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            tag: The tag name to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_tag_to_room(
+                user_id, room_id, tag, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._add_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+        """Remove a tag from a room for a user.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.remove_tag_from_room(
+                user_id, room_id, tag
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+            return max_stream_id
+        else:
+            response = await self._remove_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+            )
+            return response["max_stream_id"]
+
+
 class AccountDataEventSource:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..5ecb2da1ac 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 import twisted
 import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
 
 from synapse.app import check_bind_error
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
 
 
 class AcmeHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.reactor = hs.get_reactor()
         self._acme_domain = hs.config.acme_domain
 
-    async def start_listening(self):
+    async def start_listening(self) -> None:
         from synapse.handlers import acme_issuing_service
 
         # Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
             logger.error(ACME_REGISTER_FAIL_ERROR)
             raise
 
-    async def provision_certificate(self):
+    async def provision_certificate(self) -> None:
 
         logger.warning("Reprovisioning %s", self._acme_domain)
 
@@ -110,5 +114,3 @@ class AcmeHandler:
         except Exception:
             logger.exception("Failed saving!")
             raise
-
-        return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
 imported conditionally.
 """
 import logging
+from typing import Dict, Iterable, List
 
 import attr
+import pem
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
 from zope.interface import implementer
 
 from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
 from twisted.python.filepath import FilePath
 from twisted.python.url import URL
+from twisted.web.resource import IResource
 
 logger = logging.getLogger(__name__)
 
 
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+    reactor: IReactorTCP,
+    acme_url: str,
+    account_key_file: str,
+    well_known_resource: IResource,
+) -> AcmeIssuingService:
     """Create an ACME issuing service, and attach it to a web Resource
 
     Args:
         reactor: twisted reactor
-        acme_url (str): URL to use to request certificates
-        account_key_file (str): where to store the account key
-        well_known_resource (twisted.web.IResource): web resource for .well-known.
+        acme_url: URL to use to request certificates
+        account_key_file: where to store the account key
+        well_known_resource: web resource for .well-known.
             we will attach a child resource for "acme-challenge".
 
     Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
     A store that only stores in memory.
     """
 
-    certs = attr.ib(default=attr.Factory(dict))
+    certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
 
-    def store(self, server_name, pem_objects):
+    def store(
+        self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+    ) -> defer.Deferred:
         self.certs[server_name] = [o.as_bytes() for o in pem_objects]
         return defer.succeed(None)
 
 
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
     """Load the ACME account key from a file, creating it if it does not exist.
 
     Args:
-        key_file (str): name of the file to use as the account key
+        key_file: name of the file to use as the account key
     """
     # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
     # hardcode the 'client.key' filename
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a703944543..37e63da9b1 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -13,27 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import logging
-from typing import List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
 
 from synapse.api.constants import Membership
-from synapse.events import FrozenEvent
-from synapse.types import RoomStreamToken, StateMap
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class AdminHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    async def get_whois(self, user):
+    async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
 
         sessions = await self.store.get_user_ip_and_agents(user)
@@ -53,7 +57,7 @@ class AdminHandler(BaseHandler):
 
         return ret
 
-    async def get_user(self, user):
+    async def get_user(self, user: UserID) -> Optional[JsonDict]:
         """Function to get user details"""
         ret = await self.store.get_user_by_id(user.to_string())
         if ret:
@@ -64,12 +68,12 @@ class AdminHandler(BaseHandler):
             ret["threepids"] = threepids
         return ret
 
-    async def export_user_data(self, user_id, writer):
+    async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
         """Write all data we have on the user to the given writer.
 
         Args:
-            user_id (str)
-            writer (ExfiltrationWriter)
+            user_id: The user ID to fetch data of.
+            writer: The writer to write to.
 
         Returns:
             Resolves when all data for a user has been written.
@@ -128,7 +132,8 @@ class AdminHandler(BaseHandler):
             from_key = RoomStreamToken(0, 0)
             to_key = RoomStreamToken(None, stream_ordering)
 
-            written_events = set()  # Events that we've processed in this room
+            # Events that we've processed in this room
+            written_events = set()  # type: Set[str]
 
             # We need to track gaps in the events stream so that we can then
             # write out the state at those events. We do this by keeping track
@@ -140,8 +145,8 @@ class AdminHandler(BaseHandler):
 
             # The reverse mapping to above, i.e. map from unseen event to events
             # that have the unseen event in their prev_events, i.e. the unseen
-            # events "children". dict[str, set[str]]
-            unseen_to_child_events = {}
+            # events "children".
+            unseen_to_child_events = {}  # type: Dict[str, Set[str]]
 
             # We fetch events in the room the user could see by fetching *all*
             # events that we have and then filtering, this isn't the most
@@ -197,38 +202,46 @@ class AdminHandler(BaseHandler):
         return writer.finished()
 
 
-class ExfiltrationWriter:
+class ExfiltrationWriter(metaclass=abc.ABCMeta):
     """Interface used to specify how to write exported data.
     """
 
-    def write_events(self, room_id: str, events: List[FrozenEvent]):
+    @abc.abstractmethod
+    def write_events(self, room_id: str, events: List[EventBase]) -> None:
         """Write a batch of events for a room.
         """
-        pass
+        raise NotImplementedError()
 
-    def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
+    @abc.abstractmethod
+    def write_state(
+        self, room_id: str, event_id: str, state: StateMap[EventBase]
+    ) -> None:
         """Write the state at the given event in the room.
 
         This only gets called for backward extremities rather than for each
         event.
         """
-        pass
+        raise NotImplementedError()
 
-    def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
+    @abc.abstractmethod
+    def write_invite(
+        self, room_id: str, event: EventBase, state: StateMap[dict]
+    ) -> None:
         """Write an invite for the room, with associated invite state.
 
         Args:
-            room_id
-            event
-            state: A subset of the state at the
-                invite, with a subset of the event keys (type, state_key
-                content and sender)
+            room_id: The room ID the invite is for.
+            event: The invite event.
+            state: A subset of the state at the invite, with a subset of the
+                event keys (type, state_key content and sender).
         """
+        raise NotImplementedError()
 
-    def finished(self):
+    @abc.abstractmethod
+    def finished(self) -> Any:
         """Called when all data has successfully been exported and written.
 
         This functions return value is passed to the caller of
         `export_user_data`.
         """
-        pass
+        raise NotImplementedError()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..a19c556437 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import logging
 import time
 import unicodedata
@@ -22,6 +21,7 @@ import urllib.parse
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Dict,
     Iterable,
@@ -36,6 +36,8 @@ import attr
 import bcrypt
 import pymacaroons
 
+from twisted.web.http import Request
+
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     AuthError,
@@ -47,20 +49,25 @@ from synapse.api.errors import (
     UserDeactivatedError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.ui_auth import (
+    INTERACTIVE_AUTH_CHECKERS,
+    UIAuthSessionDataConstants,
+)
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http import get_request_user_agent
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
+from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
-from ._base import BaseHandler
-
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
@@ -193,39 +200,27 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = (
-            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
-        )
-
-        # we keep this as a list despite the O(N^2) implication so that we can
-        # keep PASSWORD first and avoid confusing clients which pick the first
-        # type in the list. (NB that the spec doesn't require us to do so and
-        # clients which favour types that they don't understand over those that
-        # they do are technically broken)
+        self._password_localdb_enabled = hs.config.password_localdb_enabled
 
         # start out by assuming PASSWORD is enabled; we will remove it later if not.
-        login_types = []
-        if hs.config.password_localdb_enabled:
-            login_types.append(LoginType.PASSWORD)
+        login_types = set()
+        if self._password_localdb_enabled:
+            login_types.add(LoginType.PASSWORD)
 
         for provider in self.password_providers:
-            if hasattr(provider, "get_supported_login_types"):
-                for t in provider.get_supported_login_types().keys():
-                    if t not in login_types:
-                        login_types.append(t)
+            login_types.update(provider.get_supported_login_types().keys())
 
         if not self._password_enabled:
+            login_types.discard(LoginType.PASSWORD)
+
+        # Some clients just pick the first type in the list. In this case, we want
+        # them to use PASSWORD (rather than token or whatever), so we want to make sure
+        # that comes first, where it's present.
+        self._supported_login_types = []
+        if LoginType.PASSWORD in login_types:
+            self._supported_login_types.append(LoginType.PASSWORD)
             login_types.remove(LoginType.PASSWORD)
-
-        self._supported_login_types = login_types
-
-        # Login types and UI Auth types have a heavy overlap, but are not
-        # necessarily identical. Login types have SSO (and other login types)
-        # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
-        ui_auth_types = login_types.copy()
-        if self._sso_enabled:
-            ui_auth_types.append(LoginType.SSO)
-        self._supported_ui_auth_types = ui_auth_types
+        self._supported_login_types.extend(login_types)
 
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
@@ -235,6 +230,9 @@ class AuthHandler(BaseHandler):
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
         )
 
+        # The number of seconds to keep a UI auth session active.
+        self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout
+
         # Ratelimitier for failed /login attempts
         self._failed_login_attempts_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -266,10 +264,6 @@ class AuthHandler(BaseHandler):
         # authenticating for an operation to occur on their account.
         self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
 
-        # The following template is shown after a successful user interactive
-        # authentication session. It tells the user they can close the window.
-        self._sso_auth_success_template = hs.config.sso_auth_success_template
-
         # The following template is shown during the SSO authentication process if
         # the account is deactivated.
         self._sso_account_deactivated_template = (
@@ -290,9 +284,8 @@ class AuthHandler(BaseHandler):
         requester: Requester,
         request: SynapseRequest,
         request_body: Dict[str, Any],
-        clientip: str,
         description: str,
-    ) -> Tuple[dict, str]:
+    ) -> Tuple[dict, Optional[str]]:
         """
         Checks that the user is who they claim to be, via a UI auth.
 
@@ -307,8 +300,6 @@ class AuthHandler(BaseHandler):
 
             request_body: The body of the request sent by the client
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
@@ -319,7 +310,8 @@ class AuthHandler(BaseHandler):
                 have been given only in a previous call).
 
                 'session_id' is the ID of this session, either passed in by the
-                client or assigned by this call
+                client or assigned by this call. This is None if UI auth was
+                skipped (by re-using a previous validation).
 
         Raises:
             InteractiveAuthIncompleteError if the client has not yet completed
@@ -333,40 +325,90 @@ class AuthHandler(BaseHandler):
 
         """
 
-        user_id = requester.user.to_string()
+        if self._ui_auth_session_timeout:
+            last_validated = await self.store.get_access_token_last_validated(
+                requester.access_token_id
+            )
+            if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout:
+                # Return the input parameters, minus the auth key, which matches
+                # the logic in check_ui_auth.
+                request_body.pop("auth", None)
+                return request_body, None
+
+        requester_user_id = requester.user.to_string()
 
         # Check if we should be ratelimited due to too many previous failed attempts
-        self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+        self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
 
         # build a list of supported flows
-        flows = [[login_type] for login_type in self._supported_ui_auth_types]
+        supported_ui_auth_types = await self._get_available_ui_auth_types(
+            requester.user
+        )
+        flows = [[login_type] for login_type in supported_ui_auth_types]
+
+        def get_new_session_data() -> JsonDict:
+            return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
 
         try:
             result, params, session_id = await self.check_ui_auth(
-                flows, request, request_body, clientip, description
+                flows, request, request_body, description, get_new_session_data,
             )
         except LoginError:
             # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
-            self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+            self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
             raise
 
         # find the completed login type
-        for login_type in self._supported_ui_auth_types:
+        for login_type in supported_ui_auth_types:
             if login_type not in result:
                 continue
 
-            user_id = result[login_type]
+            validated_user_id = result[login_type]
             break
         else:
             # this can't happen
             raise Exception("check_auth returned True but no successful login type")
 
         # check that the UI auth matched the access token
-        if user_id != requester.user.to_string():
+        if validated_user_id != requester_user_id:
             raise AuthError(403, "Invalid auth")
 
+        # Note that the access token has been validated.
+        await self.store.update_access_token_last_validated(requester.access_token_id)
+
         return params, session_id
 
+    async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+        """Get a list of the authentication types this user can use
+        """
+
+        ui_auth_types = set()
+
+        # if the HS supports password auth, and the user has a non-null password, we
+        # support password auth
+        if self._password_localdb_enabled and self._password_enabled:
+            lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+            if lookupres:
+                _, password_hash = lookupres
+                if password_hash:
+                    ui_auth_types.add(LoginType.PASSWORD)
+
+        # also allow auth from password providers
+        for provider in self.password_providers:
+            for t in provider.get_supported_login_types().keys():
+                if t == LoginType.PASSWORD and not self._password_enabled:
+                    continue
+                ui_auth_types.add(t)
+
+        # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+        # from sso to mxid.
+        if await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user.to_string()
+        ):
+            ui_auth_types.add(LoginType.SSO)
+
+        return ui_auth_types
+
     def get_enabled_auth_types(self):
         """Return the enabled user-interactive authentication types
 
@@ -380,8 +422,8 @@ class AuthHandler(BaseHandler):
         flows: List[List[str]],
         request: SynapseRequest,
         clientdict: Dict[str, Any],
-        clientip: str,
         description: str,
+        get_new_session_data: Optional[Callable[[], JsonDict]] = None,
     ) -> Tuple[dict, dict, str]:
         """
         Takes a dictionary sent by the client in the login / registration
@@ -402,11 +444,16 @@ class AuthHandler(BaseHandler):
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
+            get_new_session_data:
+                an optional callback which will be called when starting a new session.
+                it should return data to be stored as part of the session.
+
+                The keys of the returned data should be entries in
+                UIAuthSessionDataConstants.
+
         Returns:
             A tuple of (creds, params, session_id).
 
@@ -423,13 +470,10 @@ class AuthHandler(BaseHandler):
                 all the stages in any of the permitted flows.
         """
 
-        authdict = None
         sid = None  # type: Optional[str]
-        if clientdict and "auth" in clientdict:
-            authdict = clientdict["auth"]
-            del clientdict["auth"]
-            if "session" in authdict:
-                sid = authdict["session"]
+        authdict = clientdict.pop("auth", {})
+        if "session" in authdict:
+            sid = authdict["session"]
 
         # Convert the URI and method to strings.
         uri = request.uri.decode("utf-8")
@@ -437,10 +481,15 @@ class AuthHandler(BaseHandler):
 
         # If there's no session ID, create a new session.
         if not sid:
+            new_session_data = get_new_session_data() if get_new_session_data else {}
+
             session = await self.store.create_ui_auth_session(
                 clientdict, uri, method, description
             )
 
+            for k, v in new_session_data.items():
+                await self.set_session_data(session.session_id, k, v)
+
         else:
             try:
                 session = await self.store.get_ui_auth_session(sid)
@@ -496,7 +545,8 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
-        user_agent = request.get_user_agent("")
+        user_agent = get_request_user_agent(request)
+        clientip = request.getClientIP()
 
         await self.store.add_user_agent_ip_to_ui_auth_session(
             session.session_id, user_agent, clientip
@@ -518,22 +568,14 @@ class AuthHandler(BaseHandler):
                         session.session_id, login_type, result
                     )
             except LoginError as e:
-                if login_type == LoginType.EMAIL_IDENTITY:
-                    # riot used to have a bug where it would request a new
-                    # validation token (thus sending a new email) each time it
-                    # got a 401 with a 'flows' field.
-                    # (https://github.com/vector-im/vector-web/issues/2447).
-                    #
-                    # Grandfather in the old behaviour for now to avoid
-                    # breaking old riot deployments.
-                    raise
-
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
                 errordict = e.error_dict()
 
         creds = await self.store.get_completed_ui_auth_stages(session.session_id)
         for f in flows:
+            # If all the required credentials have been supplied, the user has
+            # successfully completed the UI auth process!
             if len(set(f) - set(creds)) == 0:
                 # it's very useful to know what args are stored, but this can
                 # include the password in the case of registering, so only log
@@ -599,7 +641,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key to store the data under. An entry from
+                UIAuthSessionDataConstants.
             value: The data to store
         """
         try:
@@ -615,7 +658,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key the data was stored under. An entry from
+                UIAuthSessionDataConstants.
             default: Value to return if the key has not been set
         """
         try:
@@ -709,6 +753,7 @@ class AuthHandler(BaseHandler):
         device_id: Optional[str],
         valid_until_ms: Optional[int],
         puppets_user_id: Optional[str] = None,
+        is_appservice_ghost: bool = False,
     ) -> str:
         """
         Creates a new access token for the user with the given user ID.
@@ -725,6 +770,7 @@ class AuthHandler(BaseHandler):
                we should always have a device ID)
             valid_until_ms: when the token is valid until. None for
                 no expiry.
+            is_appservice_ghost: Whether the user is an application ghost user
         Returns:
               The access token for the user's session.
         Raises:
@@ -745,7 +791,11 @@ class AuthHandler(BaseHandler):
                 "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
             )
 
-        await self.auth.check_auth_blocking(user_id)
+        if (
+            not is_appservice_ghost
+            or self.hs.config.appservice.track_appservice_user_ips
+        ):
+            await self.auth.check_auth_blocking(user_id)
 
         access_token = self.macaroon_gen.generate_access_token(user_id)
         await self.store.add_access_token_to_user(
@@ -831,7 +881,7 @@ class AuthHandler(BaseHandler):
 
     async def validate_login(
         self, login_submission: Dict[str, Any], ratelimit: bool = False,
-    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
         """Authenticates the user for the /login API
 
         Also used by the user-interactive auth flow to validate auth types which don't
@@ -974,7 +1024,7 @@ class AuthHandler(BaseHandler):
 
     async def _validate_userid_login(
         self, username: str, login_submission: Dict[str, Any],
-    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
         """Helper for validate_login
 
         Handles login, once we've mapped 3pids onto userids
@@ -1029,7 +1079,7 @@ class AuthHandler(BaseHandler):
             if result:
                 return result
 
-        if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+        if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True
 
             # we've already checked that there is a (valid) password field
@@ -1052,7 +1102,7 @@ class AuthHandler(BaseHandler):
 
     async def check_password_provider_3pid(
         self, medium: str, address: str, password: str
-    ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
+    ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
         """Check if a password provider is able to validate a thirdparty login
 
         Args:
@@ -1283,12 +1333,12 @@ class AuthHandler(BaseHandler):
         else:
             return False
 
-    async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+    async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
         """
         Get the HTML for the SSO redirect confirmation page.
 
         Args:
-            redirect_url: The URL to redirect to the SSO provider.
+            request: The incoming HTTP request
             session_id: The user interactive authentication session ID.
 
         Returns:
@@ -1298,38 +1348,48 @@ class AuthHandler(BaseHandler):
             session = await self.store.get_ui_auth_session(session_id)
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
-        return self._sso_auth_confirm_template.render(
-            description=session.description, redirect_url=redirect_url,
+
+        user_id_to_verify = await self.get_session_data(
+            session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
+        idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user_id_to_verify
         )
 
-    async def complete_sso_ui_auth(
-        self, registered_user_id: str, session_id: str, request: SynapseRequest,
-    ):
-        """Having figured out a mxid for this user, complete the HTTP request
+        if not idps:
+            # we checked that the user had some remote identities before offering an SSO
+            # flow, so either it's been deleted or the client has requested SSO despite
+            # it not being offered.
+            raise SynapseError(400, "User has no SSO identities")
 
-        Args:
-            registered_user_id: The registered user ID to complete SSO login for.
-            request: The request to complete.
-            client_redirect_url: The URL to which to redirect the user at the end of the
-                process.
-        """
-        # Mark the stage of the authentication as successful.
-        # Save the user who authenticated with SSO, this will be used to ensure
-        # that the account be modified is also the person who logged in.
-        await self.store.mark_ui_auth_stage_complete(
-            session_id, LoginType.SSO, registered_user_id
+        # for now, just pick one
+        idp_id, sso_auth_provider = next(iter(idps.items()))
+        if len(idps) > 0:
+            logger.warning(
+                "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+                "picking %r",
+                user_id_to_verify,
+                idp_id,
+            )
+
+        redirect_url = await sso_auth_provider.handle_redirect_request(
+            request, None, session_id
         )
 
-        # Render the HTML and return.
-        html = self._sso_auth_success_template
-        respond_with_html(request, 200, html)
+        return self._sso_auth_confirm_template.render(
+            description=session.description,
+            redirect_url=redirect_url,
+            idp=sso_auth_provider,
+        )
 
     async def complete_sso_login(
         self,
         registered_user_id: str,
-        request: SynapseRequest,
+        request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1340,6 +1400,8 @@ class AuthHandler(BaseHandler):
                 process.
             extra_attributes: Extra attributes which will be passed to the client
                 during successful login. Must be JSON serializable.
+            new_user: True if we should use wording appropriate to a user who has just
+                registered.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1348,22 +1410,37 @@ class AuthHandler(BaseHandler):
             respond_with_html(request, 403, self._sso_account_deactivated_template)
             return
 
+        profile = await self.store.get_profileinfo(
+            UserID.from_string(registered_user_id).localpart
+        )
+
         self._complete_sso_login(
-            registered_user_id, request, client_redirect_url, extra_attributes
+            registered_user_id,
+            request,
+            client_redirect_url,
+            extra_attributes,
+            new_user=new_user,
+            user_profile_data=profile,
         )
 
     def _complete_sso_login(
         self,
         registered_user_id: str,
-        request: SynapseRequest,
+        request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
+        user_profile_data: Optional[ProfileInfo] = None,
     ):
         """
         The synchronous portion of complete_sso_login.
 
         This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
         """
+
+        if user_profile_data is None:
+            user_profile_data = ProfileInfo(None, None)
+
         # Store any extra attributes which will be passed in the login response.
         # Note that this is per-user so it may overwrite a previous value, this
         # is considered OK since the newest SSO attributes should be most valid.
@@ -1401,6 +1478,9 @@ class AuthHandler(BaseHandler):
             display_url=redirect_url_no_params,
             redirect_url=redirect_url,
             server_name=self._server_name,
+            new_user=new_user,
+            user_id=registered_user_id,
+            user_profile=user_profile_data,
         )
         respond_with_html(request, 200, html)
 
@@ -1438,8 +1518,8 @@ class AuthHandler(BaseHandler):
     @staticmethod
     def add_query_param_to_url(url: str, param_name: str, param: Any):
         url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({param_name: param})
+        query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+        query.append((param_name, param))
         url_parts[4] = urllib.parse.urlencode(query)
         return urllib.parse.urlunparse(url_parts)
 
@@ -1609,6 +1689,6 @@ class PasswordProvider:
 
         # This might return an awaitable, if it does block the log out
         # until it completes.
-        result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
-        if inspect.isawaitable(result):
-            await result
+        await maybe_awaitable(
+            g(user_id=user_id, device_id=device_id, access_token=access_token,)
+        )
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f4ea0a9767..bd35d1fb87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -13,13 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Optional
 from xml.etree import ElementTree as ET
 
+import attr
+
 from twisted.web.client import PartialDownloadError
 
-from synapse.api.errors import Codes, LoginError
+from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
@@ -29,6 +32,26 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class CasError(Exception):
+    """Used to catch errors when validating the CAS ticket.
+    """
+
+    def __init__(self, error, error_description=None):
+        self.error = error
+        self.error_description = error_description
+
+    def __str__(self):
+        if self.error_description:
+            return "{}: {}".format(self.error, self.error_description)
+        return self.error
+
+
+@attr.s(slots=True, frozen=True)
+class CasResponse:
+    username = attr.ib(type=str)
+    attributes = attr.ib(type=Dict[str, Optional[str]])
+
+
 class CasHandler:
     """
     Utility class for to handle the response from a CAS SSO service.
@@ -40,6 +63,7 @@ class CasHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
+        self._store = hs.get_datastore()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
@@ -50,6 +74,21 @@ class CasHandler:
 
         self._http_client = hs.get_proxied_http_client()
 
+        # identifier for the external_ids table
+        self.idp_id = "cas"
+
+        # user-facing name of this auth provider
+        self.idp_name = "CAS"
+
+        # we do not currently support brands/icons for CAS auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
+        self.idp_brand = None
+
+        self._sso_handler = hs.get_sso_handler()
+
+        self._sso_handler.register_identity_provider(self)
+
     def _build_service_param(self, args: Dict[str, str]) -> str:
         """
         Generates a value to use as the "service" parameter when redirecting or
@@ -61,22 +100,24 @@ class CasHandler:
         Returns:
             The URL to use as a "service" parameter.
         """
-        return "%s%s?%s" % (
-            self._cas_service_url,
-            "/_matrix/client/r0/login/cas/ticket",
-            urllib.parse.urlencode(args),
-        )
+        return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
 
     async def _validate_ticket(
         self, ticket: str, service_args: Dict[str, str]
-    ) -> Tuple[str, Optional[str]]:
+    ) -> CasResponse:
         """
-        Validate a CAS ticket with the server, parse the response, and return the user and display name.
+        Validate a CAS ticket with the server, and return the parsed the response.
 
         Args:
             ticket: The CAS ticket from the client.
             service_args: Additional arguments to include in the service URL.
-                Should be the same as those passed to `get_redirect_url`.
+                Should be the same as those passed to `handle_redirect_request`.
+
+        Raises:
+            CasError: If there's an error parsing the CAS response.
+
+        Returns:
+            The parsed CAS response.
         """
         uri = self._cas_server_url + "/proxyValidate"
         args = {
@@ -89,77 +130,91 @@ class CasHandler:
             # Twisted raises this error if the connection is closed,
             # even if that's being used old-http style to signal end-of-data
             body = pde.response
+        except HttpResponseException as e:
+            description = (
+                (
+                    'Authorization server responded with a "{status}" error '
+                    "while exchanging the authorization code."
+                ).format(status=e.code),
+            )
+            raise CasError("server_error", description) from e
 
-        user, attributes = self._parse_cas_response(body)
-        displayname = attributes.pop(self._cas_displayname_attribute, None)
-
-        for required_attribute, required_value in self._cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in attributes:
-                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-        return user, displayname
+        return self._parse_cas_response(body)
 
-    def _parse_cas_response(
-        self, cas_response_body: bytes
-    ) -> Tuple[str, Dict[str, Optional[str]]]:
+    def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
         """
         Retrieve the user and other parameters from the CAS response.
 
         Args:
             cas_response_body: The response from the CAS query.
 
+        Raises:
+            CasError: If there's an error parsing the CAS response.
+
         Returns:
-            A tuple of the user and a mapping of other attributes.
+            The parsed CAS response.
         """
-        user = None
-        attributes = {}
-        try:
-            root = ET.fromstring(cas_response_body)
-            if not root.tag.endswith("serviceResponse"):
-                raise Exception("root of CAS response is not serviceResponse")
-            success = root[0].tag.endswith("authenticationSuccess")
-            for child in root[0]:
-                if child.tag.endswith("user"):
-                    user = child.text
-                if child.tag.endswith("attributes"):
-                    for attribute in child:
-                        # ElementTree library expands the namespace in
-                        # attribute tags to the full URL of the namespace.
-                        # We don't care about namespace here and it will always
-                        # be encased in curly braces, so we remove them.
-                        tag = attribute.tag
-                        if "}" in tag:
-                            tag = tag.split("}")[1]
-                        attributes[tag] = attribute.text
-            if user is None:
-                raise Exception("CAS response does not contain user")
-        except Exception:
-            logger.exception("Error parsing CAS response")
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-        if not success:
-            raise LoginError(
-                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+
+        # Ensure the response is valid.
+        root = ET.fromstring(cas_response_body)
+        if not root.tag.endswith("serviceResponse"):
+            raise CasError(
+                "missing_service_response",
+                "root of CAS response is not serviceResponse",
             )
-        return user, attributes
 
-    def get_redirect_url(self, service_args: Dict[str, str]) -> str:
-        """
-        Generates a URL for the CAS server where the client should be redirected.
+        success = root[0].tag.endswith("authenticationSuccess")
+        if not success:
+            raise CasError("unsucessful_response", "Unsuccessful CAS response")
+
+        # Iterate through the nodes and pull out the user and any extra attributes.
+        user = None
+        attributes = {}
+        for child in root[0]:
+            if child.tag.endswith("user"):
+                user = child.text
+            if child.tag.endswith("attributes"):
+                for attribute in child:
+                    # ElementTree library expands the namespace in
+                    # attribute tags to the full URL of the namespace.
+                    # We don't care about namespace here and it will always
+                    # be encased in curly braces, so we remove them.
+                    tag = attribute.tag
+                    if "}" in tag:
+                        tag = tag.split("}")[1]
+                    attributes[tag] = attribute.text
+
+        # Ensure a user was found.
+        if user is None:
+            raise CasError("no_user", "CAS response does not contain user")
+
+        return CasResponse(user, attributes)
+
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Generates a URL for the CAS server where the client should be redirected.
 
         Args:
-            service_args: Additional arguments to include in the final redirect URL.
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
 
         Returns:
-            The URL to redirect the client to.
+            URL to redirect to
         """
+
+        if ui_auth_session_id:
+            service_args = {"session": ui_auth_session_id}
+        else:
+            assert client_redirect_url
+            service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
+
         args = urllib.parse.urlencode(
             {"service": self._build_service_param(service_args)}
         )
@@ -201,59 +256,150 @@ class CasHandler:
             args["redirectUrl"] = client_redirect_url
         if session:
             args["session"] = session
-        username, user_display_name = await self._validate_ticket(ticket, args)
-
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
 
-        # Get the matrix ID from the CAS username.
-        user_id = await self._map_cas_user_to_matrix_user(
-            username, user_display_name, user_agent, ip_address
+        try:
+            cas_response = await self._validate_ticket(ticket, args)
+        except CasError as e:
+            logger.exception("Could not validate ticket")
+            self._sso_handler.render_error(request, e.error, e.error_description, 401)
+            return
+
+        await self._handle_cas_response(
+            request, cas_response, client_redirect_url, session
         )
 
+    async def _handle_cas_response(
+        self,
+        request: SynapseRequest,
+        cas_response: CasResponse,
+        client_redirect_url: Optional[str],
+        session: Optional[str],
+    ) -> None:
+        """Handle a CAS response to a ticket request.
+
+        Assumes that the response has been validated. Maps the user onto an MXID,
+        registering them if necessary, and returns a response to the browser.
+
+        Args:
+            request: the incoming request from the browser. We'll respond to it with an
+                HTML page or a redirect
+
+            cas_response: The parsed CAS response.
+
+            client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
+                This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
+
+            session: The session parameter from the `/cas/ticket` HTTP request, if given.
+                This should be the UI Auth session id.
+        """
+
+        # first check if we're doing a UIA
         if session:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, session, request,
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self.idp_id, cas_response.username, session, request,
             )
-        else:
-            # If this not a UI auth request than there must be a redirect URL.
-            assert client_redirect_url
 
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url
-            )
+        # otherwise, we're handling a login request.
+
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
+        for required_attribute, required_value in self._cas_required_attributes.items():
+            # If required attribute was not in CAS Response - Forbidden
+            if required_attribute not in cas_response.attributes:
+                self._sso_handler.render_error(
+                    request,
+                    "unauthorised",
+                    "You are not authorised to log in here.",
+                    401,
+                )
+                return
+
+            # Also need to check value
+            if required_value is not None:
+                actual_value = cas_response.attributes[required_attribute]
+                # If required attribute value does not match expected - Forbidden
+                if required_value != actual_value:
+                    self._sso_handler.render_error(
+                        request,
+                        "unauthorised",
+                        "You are not authorised to log in here.",
+                        401,
+                    )
+                    return
+
+        # Call the mapper to register/login the user
+
+        # If this not a UI auth request than there must be a redirect URL.
+        assert client_redirect_url is not None
 
-    async def _map_cas_user_to_matrix_user(
+        try:
+            await self._complete_cas_login(cas_response, request, client_redirect_url)
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._sso_handler.render_error(request, "mapping_error", str(e))
+
+    async def _complete_cas_login(
         self,
-        remote_user_id: str,
-        display_name: Optional[str],
-        user_agent: str,
-        ip_address: str,
-    ) -> str:
+        cas_response: CasResponse,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ) -> None:
         """
-        Given a CAS username, retrieve the user ID for it and possibly register the user.
+        Given a CAS response, complete the login flow
 
-        Args:
-            remote_user_id: The username from the CAS response.
-            display_name: The display name from the CAS response.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
 
-        Returns:
-             The user ID associated with this response.
+        Args:
+            cas_response: The parsed CAS response.
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
+
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
         """
+        # Note that CAS does not support a mapping provider, so the logic is hard-coded.
+        localpart = map_username_to_mxid_localpart(cas_response.username)
+
+        async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Map from CAS attributes to user attributes.
+            """
+            # Due to the grandfathering logic matching any previously registered
+            # mxids it isn't expected for there to be any failures.
+            if failures:
+                raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
+
+            display_name = cas_response.attributes.get(
+                self._cas_displayname_attribute, None
+            )
 
-        localpart = map_username_to_mxid_localpart(remote_user_id)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+            return UserAttributes(localpart=localpart, display_name=display_name)
 
-        # If the user does not exist, register it.
-        if not registered_user_id:
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=display_name,
-                user_agent_ips=[(user_agent, ip_address)],
+        async def grandfather_existing_users() -> Optional[str]:
+            # Since CAS did not always use the user_external_ids table, always
+            # to attempt to map to existing users.
+            user_id = UserID(localpart, self._hostname).to_string()
+
+            logger.debug(
+                "Looking for existing account based on mapped %s", user_id,
             )
 
-        return registered_user_id
+            users = await self._store.get_users_by_id_case_insensitive(user_id)
+            if users:
+                registered_user_id = list(users.keys())[0]
+                logger.info("Grandfathering mapping to %s", registered_user_id)
+                return registered_user_id
+
+            return None
+
+        await self._sso_handler.complete_sso_login_request(
+            self.idp_id,
+            cas_response.username,
+            request,
+            client_redirect_url,
+            cas_response_to_user_attributes,
+            grandfather_existing_users,
+        )
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e808142365..c4a3b26a84 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester
+from synapse.types import Requester, UserID, create_requester
 
 from ._base import BaseHandler
 
@@ -38,6 +38,7 @@ class DeactivateAccountHandler(BaseHandler):
         self._device_handler = hs.get_device_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._identity_handler = hs.get_identity_handler()
+        self._profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
         self._server_name = hs.hostname
 
@@ -52,16 +53,23 @@ class DeactivateAccountHandler(BaseHandler):
         self._account_validity_enabled = hs.config.account_validity.enabled
 
     async def deactivate_account(
-        self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+        self,
+        user_id: str,
+        erase_data: bool,
+        requester: Requester,
+        id_server: Optional[str] = None,
+        by_admin: bool = False,
     ) -> bool:
         """Deactivate a user's account
 
         Args:
             user_id: ID of user to be deactivated
             erase_data: whether to GDPR-erase the user's data
+            requester: The user attempting to make this change.
             id_server: Use the given identity server when unbinding
                 any threepids. If None then will attempt to unbind using the
                 identity server specified when binding (if known).
+            by_admin: Whether this change was made by an administrator.
 
         Returns:
             True if identity server supports removing threepids, otherwise False.
@@ -121,6 +129,12 @@ class DeactivateAccountHandler(BaseHandler):
 
         # Mark the user as erased, if they asked for that
         if erase_data:
+            user = UserID.from_string(user_id)
+            # Remove avatar URL from this user
+            await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+            # Remove displayname from this user
+            await self._profile_handler.set_displayname(user, requester, "", by_admin)
+
             logger.info("Marking %s as erased", user_id)
             await self.store.mark_user_erased(user_id)
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index debb1b4f29..0863154f7a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
         self._auth_handler = hs.get_auth_handler()
 
     @trace
-    async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
+    async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
         Retrieve the given user's devices
 
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
         return devices
 
     @trace
-    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+    async def get_device(self, user_id: str, device_id: str) -> JsonDict:
         """ Retrieve the given device
 
         Args:
@@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 def _update_device_from_client_ips(
-    device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -946,8 +946,8 @@ class DeviceListUpdater:
     async def process_cross_signing_key_update(
         self,
         user_id: str,
-        master_key: Optional[Dict[str, Any]],
-        self_signing_key: Optional[Dict[str, Any]],
+        master_key: Optional[JsonDict],
+        self_signing_key: Optional[JsonDict],
     ) -> List[str]:
         """Process the given new master and self-signing key for the given remote user.
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 9cac5a8463..0c7737e09d 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -24,6 +24,7 @@ from synapse.logging.opentracing import (
     set_tag,
     start_active_span,
 )
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
@@ -44,13 +45,37 @@ class DeviceMessageHandler:
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
-        self.federation = hs.get_federation_sender()
 
-        hs.get_federation_registry().register_edu_handler(
-            "m.direct_to_device", self.on_direct_to_device_edu
-        )
+        # We only need to poke the federation sender explicitly if its on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the to-device replication stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the to device EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.to_device:
+            hs.get_federation_registry().register_edu_handler(
+                "m.direct_to_device", self.on_direct_to_device_edu
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.direct_to_device", hs.config.worker.writers.to_device,
+            )
 
-        self._device_list_updater = hs.get_device_handler().device_list_updater
+        # The handler to call when we think a user's device list might be out of
+        # sync. We do all device list resyncing on the master instance, so if
+        # we're on a worker we hit the device resync replication API.
+        if hs.config.worker.worker_app is None:
+            self._user_device_resync = (
+                hs.get_device_handler().device_list_updater.user_device_resync
+            )
+        else:
+            self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
+                hs
+            )
 
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         local_messages = {}
@@ -138,9 +163,7 @@ class DeviceMessageHandler:
             await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
-            run_in_background(
-                self._device_list_updater.user_device_resync, sender_user_id
-            )
+            run_in_background(self._user_device_resync, user_id=sender_user_id)
 
     async def send_device_message(
         self,
@@ -195,7 +218,8 @@ class DeviceMessageHandler:
         )
 
         log_kv({"remote_messages": remote_messages})
-        for destination in remote_messages.keys():
-            # Enqueue a new federation transaction to send the new
-            # device messages to each remote destination.
-            self.federation.send_device_messages(destination)
+        if self.federation_sender:
+            for destination in remote_messages.keys():
+                # Enqueue a new federation transaction to send the new
+                # device messages to each remote destination.
+                self.federation_sender.send_device_messages(destination)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad5683d251..abcf86352d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
                         403, "You must be in the room to create an alias for it"
                     )
 
-            if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+            if not await self.spam_checker.user_may_create_room_alias(
+                user_id, room_alias
+            ):
                 raise AuthError(403, "This user is not permitted to create this alias")
 
             if not self.config.is_alias_creation_allowed(
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
         """
         user_id = requester.user.to_string()
 
-        if not self.spam_checker.user_may_publish_room(user_id, room_id):
+        if not await self.spam_checker.user_may_publish_room(user_id, room_id):
             raise AuthError(
                 403, "This user is not permitted to publish rooms to the room list"
             )
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..8f3a6b35a4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
+    JsonDict,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class E2eKeysHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
@@ -78,7 +82,9 @@ class E2eKeysHandler:
         )
 
     @trace
-    async def query_devices(self, query_body, timeout, from_user_id):
+    async def query_devices(
+        self, query_body: JsonDict, timeout: int, from_user_id: str
+    ) -> JsonDict:
         """ Handle a device key query from a client
 
         {
@@ -98,12 +104,14 @@ class E2eKeysHandler:
         }
 
         Args:
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
         """
 
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Iterable[str]]
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
         set_tag("remote_key_query", remote_queries)
 
         # First get local devices.
-        failures = {}
+        # A map of destination -> failure response.
+        failures = {}  # type: Dict[str, JsonDict]
         results = {}
         if local_query:
             local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
         )
 
         # Now attempt to get any remote devices from our local cache.
-        remote_queries_not_in_cache = {}
+        # A map of destination -> user ID -> device IDs.
+        remote_queries_not_in_cache = {}  # type: Dict[str, Dict[str, Iterable[str]]]
         if remote_queries:
-            query_list = []
+            query_list = []  # type: List[Tuple[str, Optional[str]]]
             for user_id, device_ids in remote_queries.items():
                 if device_ids:
                     query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
         return ret
 
     async def get_cross_signing_keys_from_cache(
-        self, query, from_user_id
+        self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Get cross-signing keys for users from the database
 
         Args:
-            query (Iterable[string]) an iterable of user IDs.  A dict whose keys
+            query: an iterable of user IDs.  A dict whose keys
                 are user IDs satisfies this, so the query format used for
                 query_devices can be used here.
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
 
@@ -315,14 +325,12 @@ class E2eKeysHandler:
             if "self_signing" in user_info:
                 self_signing_keys[user_id] = user_info["self_signing"]
 
-        if (
-            from_user_id in keys
-            and keys[from_user_id] is not None
-            and "user_signing" in keys[from_user_id]
-        ):
-            # users can see other users' master and self-signing keys, but can
-            # only see their own user-signing keys
-            user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+        # users can see other users' master and self-signing keys, but can
+        # only see their own user-signing keys
+        if from_user_id:
+            from_user_key = keys.get(from_user_id)
+            if from_user_key and "user_signing" in from_user_key:
+                user_signing_keys[from_user_id] = from_user_key["user_signing"]
 
         return {
             "master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
             A map from user_id -> device_id -> device details
         """
         set_tag("local_query", query)
-        local_query = []
+        local_query = []  # type: List[Tuple[str, Optional[str]]]
 
-        result_dict = {}
+        result_dict = {}  # type: Dict[str, Dict[str, dict]]
         for user_id, device_ids in query.items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,14 @@ class E2eKeysHandler:
         log_kv(results)
         return result_dict
 
-    async def on_federation_query_client_keys(self, query_body):
+    async def on_federation_query_client_keys(
+        self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+    ) -> JsonDict:
         """ Handle a device key query from a federated server
         """
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Optional[List[str]]]
         res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
 
@@ -397,31 +409,34 @@ class E2eKeysHandler:
         return ret
 
     @trace
-    async def claim_one_time_keys(self, query, timeout):
-        local_query = []
-        remote_queries = {}
+    async def claim_one_time_keys(
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+    ) -> JsonDict:
+        local_query = []  # type: List[Tuple[str, str, str]]
+        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
 
-        for user_id, device_keys in query.get("one_time_keys", {}).items():
+        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in device_keys.items():
+                for device_id, algorithm in one_time_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
                 domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_keys
+                remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
         set_tag("local_key_query", local_query)
         set_tag("remote_key_query", remote_queries)
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
-        json_result = {}
-        failures = {}
+        # A map of user ID -> device ID -> key ID -> key.
+        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+        failures = {}  # type: Dict[str, JsonDict]
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
-                for key_id, json_bytes in keys.items():
+                for key_id, json_str in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_bytes)
+                        key_id: json_decoder.decode(json_str)
                     }
 
         @trace
@@ -468,7 +483,9 @@ class E2eKeysHandler:
         return {"one_time_keys": json_result, "failures": failures}
 
     @tag_args
-    async def upload_keys_for_user(self, user_id, device_id, keys):
+    async def upload_keys_for_user(
+        self, user_id: str, device_id: str, keys: JsonDict
+    ) -> JsonDict:
 
         time_now = self.clock.time_msec()
 
@@ -543,8 +560,8 @@ class E2eKeysHandler:
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(
-        self, user_id, device_id, time_now, one_time_keys
-    ):
+        self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+    ) -> None:
         logger.info(
             "Adding one_time_keys %r for device %r for user %r at %d",
             one_time_keys.keys(),
@@ -585,12 +602,14 @@ class E2eKeysHandler:
         log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
         await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
 
-    async def upload_signing_keys_for_user(self, user_id, keys):
+    async def upload_signing_keys_for_user(
+        self, user_id: str, keys: JsonDict
+    ) -> JsonDict:
         """Upload signing keys for cross-signing
 
         Args:
-            user_id (string): the user uploading the keys
-            keys (dict[string, dict]): the signing keys
+            user_id: the user uploading the keys
+            keys: the signing keys
         """
 
         # if a master key is uploaded, then check it.  Otherwise, load the
@@ -667,16 +686,17 @@ class E2eKeysHandler:
 
         return {}
 
-    async def upload_signatures_for_device_keys(self, user_id, signatures):
+    async def upload_signatures_for_device_keys(
+        self, user_id: str, signatures: JsonDict
+    ) -> JsonDict:
         """Upload device signatures for cross-signing
 
         Args:
-            user_id (string): the user uploading the signatures
-            signatures (dict[string, dict[string, dict]]): map of users to
-                devices to signed keys. This is the submission from the user; an
-                exception will be raised if it is malformed.
+            user_id: the user uploading the signatures
+            signatures: map of users to devices to signed keys. This is the submission
+            from the user; an exception will be raised if it is malformed.
         Returns:
-            dict: response to be sent back to the client.  The response will have
+            The response to be sent back to the client.  The response will have
                 a "failures" key, which will be a dict mapping users to devices
                 to errors for the signatures that failed.
         Raises:
@@ -719,7 +739,9 @@ class E2eKeysHandler:
 
         return {"failures": failures}
 
-    async def _process_self_signatures(self, user_id, signatures):
+    async def _process_self_signatures(
+        self, user_id: str, signatures: JsonDict
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of the user's own keys.
 
         Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +753,14 @@ class E2eKeysHandler:
             signatures (dict[string, dict]): map of devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
-            reasons
+            A tuple of a list of signatures to store, and a map of users to
+            devices to failure reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -834,19 +855,24 @@ class E2eKeysHandler:
         return signature_list, failures
 
     def _check_master_key_signature(
-        self, user_id, master_key_id, signed_master_key, stored_master_key, devices
-    ):
+        self,
+        user_id: str,
+        master_key_id: str,
+        signed_master_key: JsonDict,
+        stored_master_key: JsonDict,
+        devices: Dict[str, Dict[str, JsonDict]],
+    ) -> List["SignatureListItem"]:
         """Check signatures of a user's master key made by their devices.
 
         Args:
-            user_id (string): the user whose master key is being checked
-            master_key_id (string): the ID of the user's master key
-            signed_master_key (dict): the user's signed master key that was uploaded
-            stored_master_key (dict): our previously-stored copy of the user's master key
-            devices (iterable(dict)): the user's devices
+            user_id: the user whose master key is being checked
+            master_key_id: the ID of the user's master key
+            signed_master_key: the user's signed master key that was uploaded
+            stored_master_key: our previously-stored copy of the user's master key
+            devices: the user's devices
 
         Returns:
-            list[SignatureListItem]: a list of signatures to store
+            A list of signatures to store
 
         Raises:
             SynapseError: if a signature is invalid
@@ -877,25 +903,26 @@ class E2eKeysHandler:
 
         return master_key_signature_list
 
-    async def _process_other_signatures(self, user_id, signatures):
+    async def _process_other_signatures(
+        self, user_id: str, signatures: Dict[str, dict]
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of other users' keys.  These will be the
         target user's master keys, signed by the uploading user's user-signing
         key.
 
         Args:
-            user_id (string): the user uploading the keys
-            signatures (dict[string, dict]): map of users to devices to signed keys
+            user_id: the user uploading the keys
+            signatures: map of users to devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
+            A list of signatures to store, and a map of users to devices to failure
             reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -983,7 +1010,7 @@ class E2eKeysHandler:
 
     async def _get_e2e_cross_signing_verify_key(
         self, user_id: str, key_type: str, from_user_id: str = None
-    ):
+    ) -> Tuple[JsonDict, str, VerifyKey]:
         """Fetch locally or remotely query for a cross-signing public key.
 
         First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1024,7 @@ class E2eKeysHandler:
                 This affects what signatures are fetched.
 
         Returns:
-            dict, str, VerifyKey: the raw key data, the key ID, and the
-                signedjson verify key
+            The raw key data, the key ID, and the signedjson verify key
 
         Raises:
             NotFoundError: if the key is not found
@@ -1135,16 +1161,18 @@ class E2eKeysHandler:
         return desired_key, desired_key_id, desired_verify_key
 
 
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+    key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
     """Check a cross-signing key uploaded by a user.  Performs some basic sanity
     checking, and ensures that it is signed, if a signature is required.
 
     Args:
-        key (dict): the key data to verify
-        user_id (str): the user whose key is being checked
-        key_type (str): the type of key that the key should be
-        signing_key (VerifyKey): (optional) the signing key that the key should
-            be signed with.  If omitted, signatures will not be checked.
+        key: the key data to verify
+        user_id: the user whose key is being checked
+        key_type: the type of key that the key should be
+        signing_key: the signing key that the key should be signed with.  If
+            omitted, signatures will not be checked.
     """
     if (
         key.get("user_id") != user_id
@@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
             )
 
 
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+    user_id: str,
+    verify_key: VerifyKey,
+    signed_device: JsonDict,
+    stored_device: JsonDict,
+) -> None:
     """Check that a signature on a device or cross-signing key is correct and
     matches the copy of the device/key that we have stored.  Throws an
     exception if an error is detected.
 
     Args:
-        user_id (str): the user ID whose signature is being checked
-        verify_key (VerifyKey): the key to verify the device with
-        signed_device (dict): the uploaded signed device data
-        stored_device (dict): our previously stored copy of the device
+        user_id: the user ID whose signature is being checked
+        verify_key: the key to verify the device with
+        signed_device: the uploaded signed device data
+        stored_device: our previously stored copy of the device
 
     Raises:
         SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
         raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
 
 
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
     if isinstance(e, SynapseError):
         return {"status": e.code, "errcode": e.errcode, "message": str(e)}
 
@@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
     return {"status": 503, "message": str(e)}
 
 
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
     old_key = json_decoder.decode(old_key_json)
 
     # if either is a string rather than an object, they must match exactly
@@ -1239,16 +1272,16 @@ class SignatureListItem:
     """An item in the signature list as used by upload_signatures_for_device_keys.
     """
 
-    signing_key_id = attr.ib()
-    target_user_id = attr.ib()
-    target_device_id = attr.ib()
-    signature = attr.ib()
+    signing_key_id = attr.ib(type=str)
+    target_user_id = attr.ib(type=str)
+    target_device_id = attr.ib(type=str)
+    signature = attr.ib(type=JsonDict)
 
 
 class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
-    def __init__(self, hs, e2e_keys_handler):
+    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}
+        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
 
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
             iterable=True,
         )
 
-    async def incoming_signing_key_update(self, origin, edu_content):
+    async def incoming_signing_key_update(
+        self, origin: str, edu_content: JsonDict
+    ) -> None:
         """Called on incoming signing key update from federation. Responsible for
         parsing the EDU and adding to pending updates list.
 
         Args:
-            origin (string): the server that sent the EDU
-            edu_content (dict): the contents of the EDU
+            origin: the server that sent the EDU
+            edu_content: the contents of the EDU
         """
 
         user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
 
         await self._handle_signing_key_updates(user_id)
 
-    async def _handle_signing_key_updates(self, user_id):
+    async def _handle_signing_key_updates(self, user_id: str) -> None:
         """Actually handle pending updates.
 
         Args:
-            user_id (string): the user whose updates we are processing
+            user_id: the user whose updates we are processing
         """
 
         device_handler = self.e2e_keys_handler.device_handler
@@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
                 # This can happen since we batch updates
                 return
 
-            device_ids = []
+            device_ids = []  # type: List[str]
 
             logger.info("pending updates: %r", pending_updates)
 
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f01b090772..622cae23be 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, List, Optional
 
 from synapse.api.errors import (
     Codes,
@@ -24,8 +25,12 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, trace
+from synapse.types import JsonDict
 from synapse.util.async_helpers import Linearizer
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
     The actual payload of the encrypted keys is completely opaque to the handler.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
         # Used to lock whenever a client is uploading key data.  This prevents collisions
@@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
         self._upload_linearizer = Linearizer("upload_room_keys_lock")
 
     @trace
-    async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> List[JsonDict]:
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
         See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
 
         Args:
-            user_id(str): the user whose keys we're getting
-            version(str): the version ID of the backup we're getting keys from
-            room_id(string): room ID to get keys for, for None to get keys for all rooms
-            session_id(string): session ID to get keys for, for None to get keys for all
+            user_id: the user whose keys we're getting
+            version: the version ID of the backup we're getting keys from
+            room_id: room ID to get keys for, for None to get keys for all rooms
+            session_id: session ID to get keys for, for None to get keys for all
                 sessions
         Raises:
             NotFoundError: if the backup version does not exist
         Returns:
-            A deferred list of dicts giving the session_data and message metadata for
+            A list of dicts giving the session_data and message metadata for
             these room keys.
         """
 
@@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
             return results
 
     @trace
-    async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def delete_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> JsonDict:
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
         See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
 
         Args:
-            user_id(str): the user whose backup we're deleting
-            version(str): the version ID of the backup we're deleting
-            room_id(string): room ID to delete keys for, for None to delete keys for all
+            user_id: the user whose backup we're deleting
+            version: the version ID of the backup we're deleting
+            room_id: room ID to delete keys for, for None to delete keys for all
                 rooms
-            session_id(string): session ID to delete keys for, for None to delete keys
+            session_id: session ID to delete keys for, for None to delete keys
                 for all sessions
         Raises:
             NotFoundError: if the backup version does not exist
@@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
             return {"etag": str(version_etag), "count": count}
 
     @trace
-    async def upload_room_keys(self, user_id, version, room_keys):
+    async def upload_room_keys(
+        self, user_id: str, version: str, room_keys: JsonDict
+    ) -> JsonDict:
         """Bulk upload a list of room keys into a given backup version, asserting
         that the given version is the current backup version.  room_keys are merged
         into the current backup as described in RoomKeysServlet.on_PUT().
 
         Args:
-            user_id(str): the user whose backup we're setting
-            version(str): the version ID of the backup we're updating
-            room_keys(dict): a nested dict describing the room_keys we're setting:
+            user_id: the user whose backup we're setting
+            version: the version ID of the backup we're updating
+            room_keys: a nested dict describing the room_keys we're setting:
 
         {
             "rooms": {
@@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
             return {"etag": str(version_etag), "count": count}
 
     @staticmethod
-    def _should_replace_room_key(current_room_key, room_key):
+    def _should_replace_room_key(
+        current_room_key: Optional[JsonDict], room_key: JsonDict
+    ) -> bool:
         """
         Determine whether to replace a given current_room_key (if any)
         with a newly uploaded room_key backup
 
         Args:
-            current_room_key (dict): Optional, the current room_key dict if any
-            room_key (dict): The new room_key dict which may or may not be fit to
+            current_room_key: Optional, the current room_key dict if any
+            room_key : The new room_key dict which may or may not be fit to
                 replace the current_room_key
 
         Returns:
@@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
         return True
 
     @trace
-    async def create_version(self, user_id, version_info):
+    async def create_version(self, user_id: str, version_info: JsonDict) -> str:
         """Create a new backup version.  This automatically becomes the new
         backup version for the user's keys; previous backups will no longer be
         writeable to.
 
         Args:
-            user_id(str): the user whose backup version we're creating
-            version_info(dict): metadata about the new version being created
+            user_id: the user whose backup version we're creating
+            version_info: metadata about the new version being created
 
         {
             "algorithm": "m.megolm_backup.v1",
@@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
         }
 
         Returns:
-            A deferred of a string that gives the new version number.
+            The new version number.
         """
 
         # TODO: Validate the JSON to make sure it has the right keys.
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
             )
             return new_version
 
-    async def get_version_info(self, user_id, version=None):
+    async def get_version_info(
+        self, user_id: str, version: Optional[str] = None
+    ) -> JsonDict:
         """Get the info about a given version of the user's backup
 
         Args:
-            user_id(str): the user whose current backup version we're querying
-            version(str): Optional; if None gives the most recent version
+            user_id: the user whose current backup version we're querying
+            version: Optional; if None gives the most recent version
                 otherwise a historical one.
         Raises:
             NotFoundError: if the requested backup version doesn't exist
         Returns:
-            A deferred of a info dict that gives the info about the new version.
+            A info dict that gives the info about the new version.
 
         {
             "version": "1234",
@@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
             return res
 
     @trace
-    async def delete_version(self, user_id, version=None):
+    async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
         """Deletes a given version of the user's e2e_room_keys backup
 
         Args:
@@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
                     raise
 
     @trace
-    async def update_version(self, user_id, version, version_info):
+    async def update_version(
+        self, user_id: str, version: str, version_info: JsonDict
+    ) -> JsonDict:
         """Update the info about a given version of the user's backup
 
         Args:
-            user_id(str): the user whose current backup version we're updating
-            version(str): the backup version we're updating
-            version_info(dict): the new information about the backup
+            user_id: the user whose current backup version we're updating
+            version: the backup version we're updating
+            version_info: the new information about the backup
         Raises:
             NotFoundError: if the requested backup version doesn't exist
         Returns:
-            A deferred of an empty dict.
+            An empty dict.
         """
         if "version" not in version_info:
             version_info["version"] = version
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090f7..eddc7582d0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
         self._message_handler = hs.get_message_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
-        self.http_client = hs.get_simple_http_client()
+        self.http_client = hs.get_proxied_blacklisted_http_client()
         self._instance_name = hs.get_instance_name()
         self._replication = hs.get_replication_data_handler()
 
@@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
         if self.hs.config.block_non_admin_invites:
             raise SynapseError(403, "This server does not accept room invites")
 
-        if not self.spam_checker.user_may_invite(
+        if not await self.spam_checker.user_may_invite(
             event.sender, event.state_key, event.room_id
         ):
             raise SynapseError(
@@ -1617,6 +1617,12 @@ class FederationHandler(BaseHandler):
         if event.state_key == self._server_notices_mxid:
             raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
+        # We retrieve the room member handler here as to not cause a cyclic dependency
+        member_handler = self.hs.get_room_member_handler()
+        # We don't rate limit based on room ID, as that should be done by
+        # sending server.
+        member_handler.ratelimit_invite(None, event.state_key)
+
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
         # join dance).
@@ -2093,6 +2099,11 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.GuestAccess and not context.rejected:
             await self.maybe_kick_guest_users(event)
 
+        # If we are going to send this event over federation we precaclculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
         return context
 
     async def _check_for_soft_fail(
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 7c06cc529e..754a22bd5f 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -29,7 +33,7 @@ def _create_rerouter(func_name):
 
     async def f(self, group_id, *args, **kwargs):
         if not GroupID.is_valid(group_id):
-            raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
 
         if self.is_mine_id(group_id):
             return await getattr(self.groups_server_handler, func_name)(
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
 
 
 class GroupsLocalWorkerHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
     get_group_role = _create_rerouter("get_group_role")
     get_group_roles = _create_rerouter("get_group_roles")
 
-    async def get_group_summary(self, group_id, requester_user_id):
+    async def get_group_summary(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the group summary for a group.
 
         If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_users_in_group(self, group_id, requester_user_id):
+    async def get_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get users in a group
         """
         if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.get_users_in_group(
+            return await self.groups_server_handler.get_users_in_group(
                 group_id, requester_user_id
             )
-            return res
 
         group_server_name = get_domain_from_id(group_id)
 
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_joined_groups(self, user_id):
+    async def get_joined_groups(self, user_id: str) -> JsonDict:
         group_ids = await self.store.get_joined_groups(user_id)
         return {"groups": group_ids}
 
-    async def get_publicised_groups_for_user(self, user_id):
+    async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
         if self.hs.is_mine_id(user_id):
             result = await self.store.get_publicised_groups_for_user(user_id)
 
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
             # TODO: Verify attestations
             return {"groups": result}
 
-    async def bulk_get_publicised_groups(self, user_ids, proxy=True):
-        destinations = {}
+    async def bulk_get_publicised_groups(
+        self, user_ids: Iterable[str], proxy: bool = True
+    ) -> JsonDict:
+        destinations = {}  # type: Dict[str, Set[str]]
         local_users = set()
 
         for user_id in user_ids:
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
             raise SynapseError(400, "Some user_ids are not local")
 
         results = {}
-        failed_results = []
+        failed_results = []  # type: List[str]
         for destination, dest_user_ids in destinations.items():
             try:
                 r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
 
 
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         # Ensure attestations get renewed
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
     set_group_join_policy = _create_rerouter("set_group_join_policy")
 
-    async def create_group(self, group_id, user_id, content):
+    async def create_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Create a group
         """
 
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
             local_attestation = None
             remote_attestation = None
         else:
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
-            content["attestation"] = local_attestation
-
-            content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
-            try:
-                res = await self.transport_client.create_group(
-                    get_domain_from_id(group_id), group_id, user_id, content
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            remote_attestation = res["attestation"]
-            await self.attestations.verify_attestation(
-                remote_attestation,
-                group_id=group_id,
-                user_id=user_id,
-                server_name=get_domain_from_id(group_id),
-            )
+            raise SynapseError(400, "Unable to create remote groups")
 
         is_publicised = content.get("publicise", False)
         token = await self.store.register_user_group_membership(
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def join_group(self, group_id, user_id, content):
+    async def join_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Request to join a group
         """
         if self.is_mine_id(group_id):
@@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def accept_invite(self, group_id, user_id, content):
+    async def accept_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Accept an invite to a group
         """
         if self.is_mine_id(group_id):
@@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def invite(self, group_id, user_id, requester_user_id, config):
+    async def invite(
+        self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+    ) -> JsonDict:
         """Invite a user to a group
         """
         content = {"requester_user_id": requester_user_id, "config": config}
@@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def on_invite(self, group_id, user_id, content):
+    async def on_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """One of our users were invited to a group
         """
         # TODO: Support auto join and rejection
@@ -484,8 +483,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return res
 
     async def remove_user_from_group(
-        self, group_id, user_id, requester_user_id, content
-    ):
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Remove a user from a group
         """
         if user_id == requester_user_id:
@@ -518,7 +517,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def user_removed_from_group(self, group_id, user_id, content):
+    async def user_removed_from_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> None:
         """One of our users was removed/kicked from a group
         """
         # TODO: Check if user in group
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b4551..8fc1e8b91c 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -27,9 +27,11 @@ from synapse.api.errors import (
     HttpResponseException,
     SynapseError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, Requester
 from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
@@ -46,15 +48,43 @@ class IdentityHandler(BaseHandler):
     def __init__(self, hs):
         super().__init__(hs)
 
+        # An HTTP client for contacting trusted URLs.
         self.http_client = SimpleHttpClient(hs)
-        # We create a blacklisting instance of SimpleHttpClient for contacting identity
-        # servers specified by clients
+        # An HTTP client for contacting identity servers specified by clients.
         self.blacklisting_http_client = SimpleHttpClient(
             hs, ip_blacklist=hs.config.federation_ip_range_blacklist
         )
-        self.federation_http_client = hs.get_http_client()
+        self.federation_http_client = hs.get_federation_http_client()
         self.hs = hs
 
+        self._web_client_location = hs.config.invite_client_location
+
+        # Ratelimiters for `/requestToken` endpoints.
+        self._3pid_validation_ratelimiter_ip = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+        self._3pid_validation_ratelimiter_address = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+
+    def ratelimit_request_token_requests(
+        self, request: SynapseRequest, medium: str, address: str,
+    ):
+        """Used to ratelimit requests to `/requestToken` by IP and address.
+
+        Args:
+            request: The associated request
+            medium: The type of threepid, e.g. "msisdn" or "email"
+            address: The actual threepid ID, e.g. the phone number or email address
+        """
+
+        self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
     async def threepid_from_creds(
         self, id_server: str, creds: Dict[str, str]
     ) -> Optional[JsonDict]:
@@ -474,6 +504,8 @@ class IdentityHandler(BaseHandler):
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
+        # It is already checked that public_baseurl is configured since this code
+        # should only be used if account_threepid_delegate_msisdn is true.
         assert self.hs.config.public_baseurl
 
         # we need to tell the client to send the token back to us, since it doesn't
@@ -803,6 +835,9 @@ class IdentityHandler(BaseHandler):
             "sender_display_name": inviter_display_name,
             "sender_avatar_url": inviter_avatar_url,
         }
+        # If a custom web client location is available, include it in the request.
+        if self._web_client_location:
+            invite_config["org.matrix.web_client_location"] = self._web_client_location
 
         # Add the identity service access token to the JSON body and use the v2
         # Identity Service endpoints if id_access_token is present
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index cb11754bf8..fbd8df9dcc 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler):
         member_event_id: str,
         is_peeking: bool,
     ) -> JsonDict:
-        room_state = await self.state_store.get_state_for_events([member_event_id])
-
-        room_state = room_state[member_event_id]
+        room_state = await self.state_store.get_state_for_event(member_event_id)
 
         limit = pagin_config.limit if pagin_config else None
         if limit is None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 96843338ae..a15336bf00 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -174,7 +174,7 @@ class MessageHandler:
                 raise NotFoundError("Can't find event for token %s" % (at_token,))
 
             visible_events = await filter_events_for_client(
-                self.storage, user_id, last_events, filter_send_to_client=False
+                self.storage, user_id, last_events, filter_send_to_client=False,
             )
 
             event = last_events[0]
@@ -432,6 +432,8 @@ class EventCreationHandler:
 
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
+        self._external_cache = hs.get_external_cache()
+
     async def create_event(
         self,
         requester: Requester,
@@ -744,7 +746,7 @@ class EventCreationHandler:
                 event.sender,
             )
 
-            spam_error = self.spam_checker.check_event_for_spam(event)
+            spam_error = await self.spam_checker.check_event_for_spam(event)
             if spam_error:
                 if not isinstance(spam_error, str):
                     spam_error = "Spam is not permitted here"
@@ -939,6 +941,8 @@ class EventCreationHandler:
 
         await self.action_generator.handle_push_actions_for_event(event, context)
 
+        await self.cache_joined_hosts_for_event(event)
+
         try:
             # If we're a worker we need to hit out to the master.
             writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
             await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
+    async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+        """Precalculate the joined hosts at the event, when using Redis, so that
+        external federation senders don't have to recalculate it themselves.
+        """
+
+        if not self._external_cache.is_enabled():
+            return
+
+        # We actually store two mappings, event ID -> prev state group,
+        # state group -> joined hosts, which is much more space efficient
+        # than event ID -> joined hosts.
+        #
+        # Note: We have to cache event ID -> prev state group, as we don't
+        # store that in the DB.
+        #
+        # Note: We always set the state group -> joined hosts cache, even if
+        # we already set it, so that the expiry time is reset.
+
+        state_entry = await self.state.resolve_state_groups_for_events(
+            event.room_id, event_ids=event.prev_event_ids()
+        )
+
+        if state_entry.state_group:
+            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+            await self._external_cache.set(
+                "event_to_prev_state_group",
+                event.event_id,
+                state_entry.state_group,
+                expiry_ms=60 * 60 * 1000,
+            )
+            await self._external_cache.set(
+                "get_joined_hosts",
+                str(state_entry.state_group),
+                list(joined_hosts),
+                expiry_ms=60 * 60 * 1000,
+            )
+
     async def _validate_canonical_alias(
         self, directory_handler, room_alias_str: str, expected_room_id: str
     ) -> None:
@@ -1261,7 +1303,7 @@ class EventCreationHandler:
                 event, context = await self.create_event(
                     requester,
                     {
-                        "type": "org.matrix.dummy_event",
+                        "type": EventTypes.Dummy,
                         "content": {},
                         "room_id": room_id,
                         "sender": user_id,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..71008ec50d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import inspect
 import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
 from urllib.parse import urlencode
 
 import attr
@@ -35,7 +35,7 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.handlers._base import BaseHandler
+from synapse.config.oidc_config import OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -71,6 +71,144 @@ JWK = Dict[str, str]
 JWKS = TypedDict("JWKS", {"keys": List[JWK]})
 
 
+class OidcHandler:
+    """Handles requests related to the OpenID Connect login flow.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._sso_handler = hs.get_sso_handler()
+
+        provider_confs = hs.config.oidc.oidc_providers
+        # we should not have been instantiated if there is no configured provider.
+        assert provider_confs
+
+        self._token_generator = OidcSessionTokenGenerator(hs)
+        self._providers = {
+            p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+        }  # type: Dict[str, OidcProvider]
+
+    async def load_metadata(self) -> None:
+        """Validate the config and load the metadata from the remote endpoint.
+
+        Called at startup to ensure we have everything we need.
+        """
+        for idp_id, p in self._providers.items():
+            try:
+                await p.load_metadata()
+                await p.load_jwks()
+            except Exception as e:
+                raise Exception(
+                    "Error while initialising OIDC provider %r" % (idp_id,)
+                ) from e
+
+    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/callback
+
+        Since we might want to display OIDC-related errors in a user-friendly
+        way, we don't raise SynapseError from here. Instead, we call
+        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+
+        Most of the OpenID Connect logic happens here:
+
+          - first, we check if there was any error returned by the provider and
+            display it
+          - then we fetch the session cookie, decode and verify it
+          - the ``state`` query parameter should match with the one stored in the
+            session cookie
+
+        Once we know the session is legit, we then delegate to the OIDC Provider
+        implementation, which will exchange the code with the provider and complete the
+        login/authentication.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+
+        # The provider might redirect with an error.
+        # In that case, just display it as-is.
+        if b"error" in request.args:
+            # error response from the auth server. see:
+            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+            error = request.args[b"error"][0].decode()
+            description = request.args.get(b"error_description", [b""])[0].decode()
+
+            # Most of the errors returned by the provider could be due by
+            # either the provider misbehaving or Synapse being misconfigured.
+            # The only exception of that is "access_denied", where the user
+            # probably cancelled the login flow. In other cases, log those errors.
+            if error != "access_denied":
+                logger.error("Error from the OIDC provider: %s %s", error, description)
+
+            self._sso_handler.render_error(request, error, description)
+            return
+
+        # otherwise, it is presumably a successful response. see:
+        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+        # Fetch the session cookie
+        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
+        if session is None:
+            logger.info("No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
+            return
+
+        # Remove the cookie. There is a good chance that if the callback failed
+        # once, it will fail next time and the code will already be exchanged.
+        # Removing it early avoids spamming the provider with token requests.
+        request.addCookie(
+            SESSION_COOKIE_NAME,
+            b"",
+            path="/_synapse/oidc",
+            expires="Thu, Jan 01 1970 00:00:00 UTC",
+            httpOnly=True,
+            sameSite="lax",
+        )
+
+        # Check for the state query parameter
+        if b"state" not in request.args:
+            logger.info("State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
+            return
+
+        state = request.args[b"state"][0].decode()
+
+        # Deserialize the session token and verify it.
+        try:
+            session_data = self._token_generator.verify_oidc_session_token(
+                session, state
+            )
+        except (MacaroonDeserializationException, ValueError) as e:
+            logger.exception("Invalid session")
+            self._sso_handler.render_error(request, "invalid_session", str(e))
+            return
+        except MacaroonInvalidSignatureException as e:
+            logger.exception("Could not verify session")
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
+            return
+
+        oidc_provider = self._providers.get(session_data.idp_id)
+        if not oidc_provider:
+            logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+            self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+            return
+
+        if b"code" not in request.args:
+            logger.info("Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
+            return
+
+        code = request.args[b"code"][0].decode()
+
+        await oidc_provider.handle_oidc_callback(request, session_data, code)
+
+
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint
     """
@@ -85,46 +223,64 @@ class OidcError(Exception):
         return self.error
 
 
-class OidcHandler(BaseHandler):
-    """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+    """Wraps the config for a single OIDC IdentityProvider
+
+    Provides methods for handling redirect requests and callbacks via that particular
+    IdP.
     """
 
-    def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+    def __init__(
+        self,
+        hs: "HomeServer",
+        token_generator: "OidcSessionTokenGenerator",
+        provider: OidcProviderConfig,
+    ):
+        self._store = hs.get_datastore()
+
+        self._token_generator = token_generator
+
         self._callback_url = hs.config.oidc_callback_url  # type: str
-        self._scopes = hs.config.oidc_scopes  # type: List[str]
-        self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
+
+        self._scopes = provider.scopes
+        self._user_profile_method = provider.user_profile_method
         self._client_auth = ClientAuth(
-            hs.config.oidc_client_id,
-            hs.config.oidc_client_secret,
-            hs.config.oidc_client_auth_method,
+            provider.client_id, provider.client_secret, provider.client_auth_method,
         )  # type: ClientAuth
-        self._client_auth_method = hs.config.oidc_client_auth_method  # type: str
+        self._client_auth_method = provider.client_auth_method
         self._provider_metadata = OpenIDProviderMetadata(
-            issuer=hs.config.oidc_issuer,
-            authorization_endpoint=hs.config.oidc_authorization_endpoint,
-            token_endpoint=hs.config.oidc_token_endpoint,
-            userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
-            jwks_uri=hs.config.oidc_jwks_uri,
+            issuer=provider.issuer,
+            authorization_endpoint=provider.authorization_endpoint,
+            token_endpoint=provider.token_endpoint,
+            userinfo_endpoint=provider.userinfo_endpoint,
+            jwks_uri=provider.jwks_uri,
         )  # type: OpenIDProviderMetadata
-        self._provider_needs_discovery = hs.config.oidc_discover  # type: bool
-        self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
-            hs.config.oidc_user_mapping_provider_config
-        )  # type: OidcMappingProvider
-        self._skip_verification = hs.config.oidc_skip_verification  # type: bool
-        self._allow_existing_users = hs.config.oidc_allow_existing_users  # type: bool
+        self._provider_needs_discovery = provider.discover
+        self._user_mapping_provider = provider.user_mapping_provider_class(
+            provider.user_mapping_provider_config
+        )
+        self._skip_verification = provider.skip_verification
+        self._allow_existing_users = provider.allow_existing_users
 
         self._http_client = hs.get_proxied_http_client()
-        self._auth_handler = hs.get_auth_handler()
-        self._registration_handler = hs.get_registration_handler()
         self._server_name = hs.config.server_name  # type: str
-        self._macaroon_secret_key = hs.config.macaroon_secret_key
 
         # identifier for the external_ids table
-        self._auth_provider_id = "oidc"
+        self.idp_id = provider.idp_id
+
+        # user-facing name of this auth provider
+        self.idp_name = provider.idp_name
+
+        # MXC URI for icon for this auth provider
+        self.idp_icon = provider.idp_icon
+
+        # optional brand identifier for this auth provider
+        self.idp_brand = provider.idp_brand
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _validate_metadata(self):
         """Verifies the provider metadata.
 
@@ -477,7 +633,7 @@ class OidcHandler(BaseHandler):
     async def handle_redirect_request(
         self,
         request: SynapseRequest,
-        client_redirect_url: bytes,
+        client_redirect_url: Optional[bytes],
         ui_auth_session_id: Optional[str] = None,
     ) -> str:
         """Handle an incoming request to /login/sso/redirect
@@ -487,7 +643,7 @@ class OidcHandler(BaseHandler):
 
           - ``client_id``: the client ID set in ``oidc_config.client_id``
           - ``response_type``: ``code``
-          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
           - ``scope``: the list of scopes set in ``oidc_config.scopes``
           - ``state``: a random string
           - ``nonce``: a random string
@@ -501,7 +657,7 @@ class OidcHandler(BaseHandler):
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
-                when everything is done
+                when everything is done (or None for UI Auth)
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
@@ -513,16 +669,22 @@ class OidcHandler(BaseHandler):
         state = generate_token()
         nonce = generate_token()
 
-        cookie = self._generate_oidc_session_token(
+        if not client_redirect_url:
+            client_redirect_url = b""
+
+        cookie = self._token_generator.generate_oidc_session_token(
             state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url.decode(),
-            ui_auth_session_id=ui_auth_session_id,
+            session_data=OidcSessionData(
+                idp_id=self.idp_id,
+                nonce=nonce,
+                client_redirect_url=client_redirect_url.decode(),
+                ui_auth_session_id=ui_auth_session_id,
+            ),
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
             cookie,
-            path="/_synapse/oidc",
+            path="/_synapse/client/oidc",
             max_age="3600",
             httpOnly=True,
             sameSite="lax",
@@ -540,22 +702,16 @@ class OidcHandler(BaseHandler):
             nonce=nonce,
         )
 
-    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_synapse/oidc/callback
+    async def handle_oidc_callback(
+        self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+    ) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/callback
 
-        Since we might want to display OIDC-related errors in a user-friendly
-        way, we don't raise SynapseError from here. Instead, we call
-        ``self._sso_handler.render_error`` which displays an HTML page for the error.
-
-        Most of the OpenID Connect logic happens here:
+        By this time we have already validated the session on the synapse side, and
+        now need to do the provider-specific operations. This includes:
 
-          - first, we check if there was any error returned by the provider and
-            display it
-          - then we fetch the session cookie, decode and verify it
-          - the ``state`` query parameter should match with the one stored in the
-            session cookie
-          - once we known this session is legit, exchange the code with the
-            provider using the ``token_endpoint`` (see ``_exchange_code``)
+          - exchange the code with the provider using the ``token_endpoint`` (see
+            ``_exchange_code``)
           - once we have the token, use it to either extract the UserInfo from
             the ``id_token`` (``_parse_id_token``), or use the ``access_token``
             to fetch UserInfo from the ``userinfo_endpoint``
@@ -565,88 +721,12 @@ class OidcHandler(BaseHandler):
 
         Args:
             request: the incoming request from the browser.
+            session_data: the session data, extracted from our cookie
+            code: The authorization code we got from the callback.
         """
-
-        # The provider might redirect with an error.
-        # In that case, just display it as-is.
-        if b"error" in request.args:
-            # error response from the auth server. see:
-            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
-            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
-            error = request.args[b"error"][0].decode()
-            description = request.args.get(b"error_description", [b""])[0].decode()
-
-            # Most of the errors returned by the provider could be due by
-            # either the provider misbehaving or Synapse being misconfigured.
-            # The only exception of that is "access_denied", where the user
-            # probably cancelled the login flow. In other cases, log those errors.
-            if error != "access_denied":
-                logger.error("Error from the OIDC provider: %s %s", error, description)
-
-            self._sso_handler.render_error(request, error, description)
-            return
-
-        # otherwise, it is presumably a successful response. see:
-        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
-
-        # Fetch the session cookie
-        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
-        if session is None:
-            logger.info("No session cookie found")
-            self._sso_handler.render_error(
-                request, "missing_session", "No session cookie found"
-            )
-            return
-
-        # Remove the cookie. There is a good chance that if the callback failed
-        # once, it will fail next time and the code will already be exchanged.
-        # Removing it early avoids spamming the provider with token requests.
-        request.addCookie(
-            SESSION_COOKIE_NAME,
-            b"",
-            path="/_synapse/oidc",
-            expires="Thu, Jan 01 1970 00:00:00 UTC",
-            httpOnly=True,
-            sameSite="lax",
-        )
-
-        # Check for the state query parameter
-        if b"state" not in request.args:
-            logger.info("State parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "State parameter is missing"
-            )
-            return
-
-        state = request.args[b"state"][0].decode()
-
-        # Deserialize the session token and verify it.
-        try:
-            (
-                nonce,
-                client_redirect_url,
-                ui_auth_session_id,
-            ) = self._verify_oidc_session_token(session, state)
-        except MacaroonDeserializationException as e:
-            logger.exception("Invalid session")
-            self._sso_handler.render_error(request, "invalid_session", str(e))
-            return
-        except MacaroonInvalidSignatureException as e:
-            logger.exception("Could not verify session")
-            self._sso_handler.render_error(request, "mismatching_session", str(e))
-            return
-
         # Exchange the code with the provider
-        if b"code" not in request.args:
-            logger.info("Code parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "Code parameter is missing"
-            )
-            return
-
-        logger.debug("Exchanging code")
-        code = request.args[b"code"][0].decode()
         try:
+            logger.debug("Exchanging code")
             token = await self._exchange_code(code)
         except OidcError as e:
             logger.exception("Could not exchange code")
@@ -668,25 +748,131 @@ class OidcHandler(BaseHandler):
         else:
             logger.debug("Extracting userinfo from id_token")
             try:
-                userinfo = await self._parse_id_token(token, nonce=nonce)
+                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
             except Exception as e:
                 logger.exception("Invalid id_token")
                 self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
+        # first check if we're doing a UIA
+        if session_data.ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_userinfo(userinfo)
+            except Exception as e:
+                logger.exception("Could not extract remote user id")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
+            )
+
+        # otherwise, it's a login
 
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_userinfo_to_user(
-                userinfo, token, user_agent, ip_address
+            await self._complete_oidc_login(
+                userinfo, token, request, session_data.client_redirect_url
             )
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
-            return
+
+    async def _complete_oidc_login(
+        self,
+        userinfo: UserInfo,
+        token: Token,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ) -> None:
+        """Given a UserInfo response, complete the login flow
+
+        UserInfo should have a claim that uniquely identifies users. This claim
+        is usually `sub`, but can be configured with `oidc_config.subject_claim`.
+        It is then used as an `external_id`.
+
+        If we don't find the user that way, we should register the user,
+        mapping the localpart and the display name from the UserInfo.
+
+        If a user already exists with the mxid we've mapped and allow_existing_users
+        is disabled, raise an exception.
+
+        Otherwise, render a redirect back to the client_redirect_url with a loginToken.
+
+        Args:
+            userinfo: an object representing the user
+            token: a dict with the tokens obtained from the provider
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
+
+        Raises:
+            MappingException: if there was an error while mapping some properties
+        """
+        try:
+            remote_user_id = self._remote_id_from_userinfo(userinfo)
+        except Exception as e:
+            raise MappingException(
+                "Failed to extract subject from OIDC response: %s" % (e,)
+            )
+
+        # Older mapping providers don't accept the `failures` argument, so we
+        # try and detect support.
+        mapper_signature = inspect.signature(
+            self._user_mapping_provider.map_user_attributes
+        )
+        supports_failures = "failures" in mapper_signature.parameters
+
+        async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Call the mapping provider to map the OIDC userinfo and token to user attributes.
+
+            This is backwards compatibility for abstraction for the SSO handler.
+            """
+            if supports_failures:
+                attributes = await self._user_mapping_provider.map_user_attributes(
+                    userinfo, token, failures
+                )
+            else:
+                # If the mapping provider does not support processing failures,
+                # do not continually generate the same Matrix ID since it will
+                # continue to already be in use. Note that the error raised is
+                # arbitrary and will get turned into a MappingException.
+                if failures:
+                    raise MappingException(
+                        "Mapping provider does not support de-duplicating Matrix IDs"
+                    )
+
+                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
+                    userinfo, token
+                )
+
+            return UserAttributes(**attributes)
+
+        async def grandfather_existing_users() -> Optional[str]:
+            if self._allow_existing_users:
+                # If allowing existing users we want to generate a single localpart
+                # and attempt to match it.
+                attributes = await oidc_response_to_user_attributes(failures=0)
+
+                user_id = UserID(attributes.localpart, self._server_name).to_string()
+                users = await self._store.get_users_by_id_case_insensitive(user_id)
+                if users:
+                    # If an existing matrix ID is returned, then use it.
+                    if len(users) == 1:
+                        previously_registered_user_id = next(iter(users))
+                    elif user_id in users:
+                        previously_registered_user_id = user_id
+                    else:
+                        # Do not attempt to continue generating Matrix IDs.
+                        raise MappingException(
+                            "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+                                user_id, users
+                            )
+                        )
+
+                    return previously_registered_user_id
+
+            return None
 
         # Mapping providers might not have get_extra_attributes: only call this
         # method if it exists.
@@ -697,22 +883,42 @@ class OidcHandler(BaseHandler):
         if get_extra_attributes:
             extra_attributes = await get_extra_attributes(userinfo, token)
 
-        # and finally complete the login
-        if ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, ui_auth_session_id, request
-            )
-        else:
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url, extra_attributes
-            )
+        await self._sso_handler.complete_sso_login_request(
+            self.idp_id,
+            remote_user_id,
+            request,
+            client_redirect_url,
+            oidc_response_to_user_attributes,
+            grandfather_existing_users,
+            extra_attributes,
+        )
+
+    def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+        """Extract the unique remote id from an OIDC UserInfo block
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+        Returns:
+            remote user id
+        """
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+        # Some OIDC providers use integer IDs, but Synapse expects external IDs
+        # to be strings.
+        return str(remote_user_id)
+
+
+class OidcSessionTokenGenerator:
+    """Methods for generating and checking OIDC Session cookies."""
+
+    def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
+        self._server_name = hs.hostname
+        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
 
-    def _generate_oidc_session_token(
+    def generate_oidc_session_token(
         self,
         state: str,
-        nonce: str,
-        client_redirect_url: str,
-        ui_auth_session_id: Optional[str],
+        session_data: "OidcSessionData",
         duration_in_ms: int = (60 * 60 * 1000),
     ) -> str:
         """Generates a signed token storing data about an OIDC session.
@@ -725,11 +931,7 @@ class OidcHandler(BaseHandler):
 
         Args:
             state: The ``state`` parameter passed to the OIDC provider.
-            nonce: The ``nonce`` parameter passed to the OIDC provider.
-            client_redirect_url: The URL the client gave when it initiated the
-                flow.
-            ui_auth_session_id: The session ID of the ongoing UI Auth (or
-                None if this is a login).
+            session_data: data to include in the session token.
             duration_in_ms: An optional duration for the token in milliseconds.
                 Defaults to an hour.
 
@@ -742,23 +944,24 @@ class OidcHandler(BaseHandler):
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = session")
         macaroon.add_first_party_caveat("state = %s" % (state,))
-        macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
+        macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
+        macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
         macaroon.add_first_party_caveat(
-            "client_redirect_url = %s" % (client_redirect_url,)
+            "client_redirect_url = %s" % (session_data.client_redirect_url,)
         )
-        if ui_auth_session_id:
+        if session_data.ui_auth_session_id:
             macaroon.add_first_party_caveat(
-                "ui_auth_session_id = %s" % (ui_auth_session_id,)
+                "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
             )
-        now = self.clock.time_msec()
+        now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
 
         return macaroon.serialize()
 
-    def _verify_oidc_session_token(
+    def verify_oidc_session_token(
         self, session: bytes, state: str
-    ) -> Tuple[str, str, Optional[str]]:
+    ) -> "OidcSessionData":
         """Verifies and extract an OIDC session token.
 
         This verifies that a given session token was issued by this homeserver
@@ -769,7 +972,10 @@ class OidcHandler(BaseHandler):
             state: The state the OIDC provider gave back
 
         Returns:
-            The nonce, client_redirect_url, and ui_auth_session_id for this session
+            The data extracted from the session cookie
+
+        Raises:
+            ValueError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -778,6 +984,7 @@ class OidcHandler(BaseHandler):
         v.satisfy_exact("type = session")
         v.satisfy_exact("state = %s" % (state,))
         v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
         # Sometimes there's a UI auth session ID, it seems to be OK to attempt
         # to always satisfy this.
@@ -786,9 +993,9 @@ class OidcHandler(BaseHandler):
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-        # Extract the `nonce`, `client_redirect_url`, and maybe the
-        # `ui_auth_session_id` from the token.
+        # Extract the session data from the token.
         nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
         client_redirect_url = self._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
@@ -799,7 +1006,12 @@ class OidcHandler(BaseHandler):
         except ValueError:
             ui_auth_session_id = None
 
-        return nonce, client_redirect_url, ui_auth_session_id
+        return OidcSessionData(
+            nonce=nonce,
+            idp_id=idp_id,
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=ui_auth_session_id,
+        )
 
     def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
         """Extracts a caveat value from a macaroon token.
@@ -812,7 +1024,7 @@ class OidcHandler(BaseHandler):
             The extracted value
 
         Raises:
-            Exception: if the caveat was not in the macaroon
+            ValueError: if the caveat was not in the macaroon
         """
         prefix = key + " = "
         for caveat in macaroon.caveats:
@@ -825,117 +1037,30 @@ class OidcHandler(BaseHandler):
         if not caveat.startswith(prefix):
             return False
         expiry = int(caveat[len(prefix) :])
-        now = self.clock.time_msec()
+        now = self._clock.time_msec()
         return now < expiry
 
-    async def _map_userinfo_to_user(
-        self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
-    ) -> str:
-        """Maps a UserInfo object to a mxid.
-
-        UserInfo should have a claim that uniquely identifies users. This claim
-        is usually `sub`, but can be configured with `oidc_config.subject_claim`.
-        It is then used as an `external_id`.
-
-        If we don't find the user that way, we should register the user,
-        mapping the localpart and the display name from the UserInfo.
-
-        If a user already exists with the mxid we've mapped and allow_existing_users
-        is disabled, raise an exception.
-
-        Args:
-            userinfo: an object representing the user
-            token: a dict with the tokens obtained from the provider
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Raises:
-            MappingException: if there was an error while mapping some properties
-
-        Returns:
-            The mxid of the user
-        """
-        try:
-            remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
-        except Exception as e:
-            raise MappingException(
-                "Failed to extract subject from OIDC response: %s" % (e,)
-            )
-        # Some OIDC providers use integer IDs, but Synapse expects external IDs
-        # to be strings.
-        remote_user_id = str(remote_user_id)
-
-        # Older mapping providers don't accept the `failures` argument, so we
-        # try and detect support.
-        mapper_signature = inspect.signature(
-            self._user_mapping_provider.map_user_attributes
-        )
-        supports_failures = "failures" in mapper_signature.parameters
-
-        async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
-            """
-            Call the mapping provider to map the OIDC userinfo and token to user attributes.
-
-            This is backwards compatibility for abstraction for the SSO handler.
-            """
-            if supports_failures:
-                attributes = await self._user_mapping_provider.map_user_attributes(
-                    userinfo, token, failures
-                )
-            else:
-                # If the mapping provider does not support processing failures,
-                # do not continually generate the same Matrix ID since it will
-                # continue to already be in use. Note that the error raised is
-                # arbitrary and will get turned into a MappingException.
-                if failures:
-                    raise MappingException(
-                        "Mapping provider does not support de-duplicating Matrix IDs"
-                    )
-
-                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
-                    userinfo, token
-                )
 
-            return UserAttributes(**attributes)
+@attr.s(frozen=True, slots=True)
+class OidcSessionData:
+    """The attributes which are stored in a OIDC session cookie"""
 
-        async def grandfather_existing_users() -> Optional[str]:
-            if self._allow_existing_users:
-                # If allowing existing users we want to generate a single localpart
-                # and attempt to match it.
-                attributes = await oidc_response_to_user_attributes(failures=0)
+    # the Identity Provider being used
+    idp_id = attr.ib(type=str)
 
-                user_id = UserID(attributes.localpart, self.server_name).to_string()
-                users = await self.store.get_users_by_id_case_insensitive(user_id)
-                if users:
-                    # If an existing matrix ID is returned, then use it.
-                    if len(users) == 1:
-                        previously_registered_user_id = next(iter(users))
-                    elif user_id in users:
-                        previously_registered_user_id = user_id
-                    else:
-                        # Do not attempt to continue generating Matrix IDs.
-                        raise MappingException(
-                            "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
-                                user_id, users
-                            )
-                        )
+    # The `nonce` parameter passed to the OIDC provider.
+    nonce = attr.ib(type=str)
 
-                    return previously_registered_user_id
+    # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
+    client_redirect_url = attr.ib(type=str)
 
-            return None
-
-        return await self._sso_handler.get_mxid_from_sso(
-            self._auth_provider_id,
-            remote_user_id,
-            user_agent,
-            ip_address,
-            oidc_response_to_user_attributes,
-            grandfather_existing_users,
-        )
+    # The session ID of the ongoing UI Auth (None if this is a login)
+    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
 
 
 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
+    "UserAttributeDict",
+    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
 )
 C = TypeVar("C")
 
@@ -1014,12 +1139,13 @@ def jinja_finalize(thing):
 env = Environment(finalize=jinja_finalize)
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class JinjaOidcMappingConfig:
-    subject_claim = attr.ib()  # type: str
-    localpart_template = attr.ib()  # type: Template
-    display_name_template = attr.ib()  # type: Optional[Template]
-    extra_attributes = attr.ib()  # type: Dict[str, Template]
+    subject_claim = attr.ib(type=str)
+    localpart_template = attr.ib(type=Optional[Template])
+    display_name_template = attr.ib(type=Optional[Template])
+    email_template = attr.ib(type=Optional[Template])
+    extra_attributes = attr.ib(type=Dict[str, Template])
 
 
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1035,50 +1161,37 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
 
-        if "localpart_template" not in config:
-            raise ConfigError(
-                "missing key: oidc_config.user_mapping_provider.config.localpart_template"
-            )
-
-        try:
-            localpart_template = env.from_string(config["localpart_template"])
-        except Exception as e:
-            raise ConfigError(
-                "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
-                % (e,)
-            )
-
-        display_name_template = None  # type: Optional[Template]
-        if "display_name_template" in config:
+        def parse_template_config(option_name: str) -> Optional[Template]:
+            if option_name not in config:
+                return None
             try:
-                display_name_template = env.from_string(config["display_name_template"])
+                return env.from_string(config[option_name])
             except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
-                    % (e,)
-                )
+                raise ConfigError("invalid jinja template", path=[option_name]) from e
+
+        localpart_template = parse_template_config("localpart_template")
+        display_name_template = parse_template_config("display_name_template")
+        email_template = parse_template_config("email_template")
 
         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
             extra_attributes_config = config.get("extra_attributes") or {}
             if not isinstance(extra_attributes_config, dict):
-                raise ConfigError(
-                    "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
-                )
+                raise ConfigError("must be a dict", path=["extra_attributes"])
 
             for key, value in extra_attributes_config.items():
                 try:
                     extra_attributes[key] = env.from_string(value)
                 except Exception as e:
                     raise ConfigError(
-                        "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
-                        % (key, e)
-                    )
+                        "invalid jinja template", path=["extra_attributes", key]
+                    ) from e
 
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
+            email_template=email_template,
             extra_attributes=extra_attributes,
         )
 
@@ -1088,25 +1201,35 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     async def map_user_attributes(
         self, userinfo: UserInfo, token: Token, failures: int
     ) -> UserAttributeDict:
-        localpart = self._config.localpart_template.render(user=userinfo).strip()
+        localpart = None
 
-        # Ensure only valid characters are included in the MXID.
-        localpart = map_username_to_mxid_localpart(localpart)
+        if self._config.localpart_template:
+            localpart = self._config.localpart_template.render(user=userinfo).strip()
 
-        # Append suffix integer if last call to this function failed to produce
-        # a usable mxid.
-        localpart += str(failures) if failures else ""
+            # Ensure only valid characters are included in the MXID.
+            localpart = map_username_to_mxid_localpart(localpart)
 
-        display_name = None  # type: Optional[str]
-        if self._config.display_name_template is not None:
-            display_name = self._config.display_name_template.render(
-                user=userinfo
-            ).strip()
+            # Append suffix integer if last call to this function failed to produce
+            # a usable mxid.
+            localpart += str(failures) if failures else ""
 
-            if display_name == "":
-                display_name = None
+        def render_template_field(template: Optional[Template]) -> Optional[str]:
+            if template is None:
+                return None
+            return template.render(user=userinfo).strip()
 
-        return UserAttributeDict(localpart=localpart, display_name=display_name)
+        display_name = render_template_field(self._config.display_name_template)
+        if display_name == "":
+            display_name = None
+
+        emails = []  # type: List[str]
+        email = render_template_field(self._config.email_template)
+        if email:
+            emails.append(email)
+
+        return UserAttributeDict(
+            localpart=localpart, display_name=display_name, emails=emails
+        )
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dee0ef45e7..c02b951031 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["displayname"]
+            return result.get("displayname")
 
     async def set_displayname(
         self,
@@ -246,7 +246,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["avatar_url"]
+            return result.get("avatar_url")
 
     async def set_avatar_url(
         self,
@@ -286,13 +286,19 @@ class ProfileHandler(BaseHandler):
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
             )
 
+        avatar_url_to_set = new_avatar_url  # type: Optional[str]
+        if new_avatar_url == "":
+            avatar_url_to_set = None
+
         # Same like set_displayname
         if by_admin:
             requester = create_requester(
                 target_user, authenticated_entity=requester.authenticated_entity
             )
 
-        await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+        await self.store.set_profile_avatar_url(
+            target_user.localpart, avatar_url_to_set
+        )
 
         if self.hs.config.user_directory_search_all_users:
             profile = await self.store.get_profileinfo(target_user.localpart)
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index a7550806e6..6bb2fd936b 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
         super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
+        self.account_data_handler = hs.get_account_data_handler()
         self.read_marker_linearizer = Linearizer(name="read_marker")
-        self.notifier = hs.get_notifier()
 
     async def received_client_read_marker(
         self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
 
             if should_update:
                 content = {"event_id": event_id}
-                max_id = await self.store.add_account_data_to_room(
+                await self.account_data_handler.add_account_data_to_room(
                     user_id, room_id, "m.fully_read", content
                 )
-                self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 153cbae7b9..cc21fc2284 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -13,31 +13,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.appservice import ApplicationService
 from synapse.handlers._base import BaseHandler
 from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
-from synapse.util.async_helpers import maybe_awaitable
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
 class ReceiptsHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.hs = hs
-        self.federation = hs.get_federation_sender()
-        hs.get_federation_registry().register_edu_handler(
-            "m.receipt", self._received_remote_receipt
-        )
+
+        # We only need to poke the federation sender explicitly if its on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the receipts stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the receipt EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.receipts:
+            hs.get_federation_registry().register_edu_handler(
+                "m.receipt", self._received_remote_receipt
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.receipt", hs.config.worker.writers.receipts,
+            )
+
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
 
-    async def _received_remote_receipt(self, origin, content):
+    async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
         """Called when we receive an EDU of type m.receipt from a remote HS.
         """
         receipts = []
@@ -64,11 +82,11 @@ class ReceiptsHandler(BaseHandler):
 
         await self._handle_new_receipts(receipts)
 
-    async def _handle_new_receipts(self, receipts):
+    async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
         """Takes a list of receipts, stores them and informs the notifier.
         """
-        min_batch_id = None
-        max_batch_id = None
+        min_batch_id = None  # type: Optional[int]
+        max_batch_id = None  # type: Optional[int]
 
         for receipt in receipts:
             res = await self.store.insert_receipt(
@@ -90,7 +108,8 @@ class ReceiptsHandler(BaseHandler):
             if max_batch_id is None or max_persisted_id > max_batch_id:
                 max_batch_id = max_persisted_id
 
-        if min_batch_id is None:
+        # Either both of these should be None or neither.
+        if min_batch_id is None or max_batch_id is None:
             # no new receipts
             return False
 
@@ -98,15 +117,15 @@ class ReceiptsHandler(BaseHandler):
 
         self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
         # Note that the min here shouldn't be relied upon to be accurate.
-        await maybe_awaitable(
-            self.hs.get_pusherpool().on_new_receipts(
-                min_batch_id, max_batch_id, affected_room_ids
-            )
+        await self.hs.get_pusherpool().on_new_receipts(
+            min_batch_id, max_batch_id, affected_room_ids
         )
 
         return True
 
-    async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+    async def received_client_receipt(
+        self, room_id: str, receipt_type: str, user_id: str, event_id: str
+    ) -> None:
         """Called when a client tells us a local user has read up to the given
         event_id in the room.
         """
@@ -122,14 +141,17 @@ class ReceiptsHandler(BaseHandler):
         if not is_new:
             return
 
-        await self.federation.send_read_receipt(receipt)
+        if self.federation_sender:
+            await self.federation_sender.send_read_receipt(receipt)
 
 
 class ReceiptEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    async def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(
+        self, from_key: int, room_ids: List[str], **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         from_key = int(from_key)
         to_key = self.get_current_key()
 
@@ -174,5 +196,5 @@ class ReceiptEventSource:
 
         return (events, to_key)
 
-    def get_current_key(self, direction="f"):
+    def get_current_key(self, direction: str = "f") -> int:
         return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0d85fd0868..49b085269b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -14,8 +14,9 @@
 # limitations under the License.
 
 """Contains functions for registering clients."""
+
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -152,7 +153,7 @@ class RegistrationHandler(BaseHandler):
         user_type: Optional[str] = None,
         default_display_name: Optional[str] = None,
         address: Optional[str] = None,
-        bind_emails: List[str] = [],
+        bind_emails: Iterable[str] = [],
         by_admin: bool = False,
         user_agent_ips: Optional[List[Tuple[str, str]]] = None,
     ) -> str:
@@ -187,7 +188,7 @@ class RegistrationHandler(BaseHandler):
         """
         self.check_registration_ratelimit(address)
 
-        result = self.spam_checker.check_registration_for_spam(
+        result = await self.spam_checker.check_registration_for_spam(
             threepid, localpart, user_agent_ips or [],
         )
 
@@ -630,6 +631,7 @@ class RegistrationHandler(BaseHandler):
         device_id: Optional[str],
         initial_display_name: Optional[str],
         is_guest: bool = False,
+        is_appservice_ghost: bool = False,
     ) -> Tuple[str, str]:
         """Register a device for a user and generate an access token.
 
@@ -651,6 +653,7 @@ class RegistrationHandler(BaseHandler):
                 device_id=device_id,
                 initial_display_name=initial_display_name,
                 is_guest=is_guest,
+                is_appservice_ghost=is_appservice_ghost,
             )
             return r["device_id"], r["access_token"]
 
@@ -672,7 +675,10 @@ class RegistrationHandler(BaseHandler):
             )
         else:
             access_token = await self._auth_handler.get_access_token_for_user_id(
-                user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
+                user_id,
+                device_id=registered_device_id,
+                valid_until_ms=valid_until_ms,
+                is_appservice_ghost=is_appservice_ghost,
             )
 
         return (registered_device_id, access_token)
@@ -688,6 +694,8 @@ class RegistrationHandler(BaseHandler):
             access_token: The access token of the newly logged in device, or
                 None if `inhibit_login` enabled.
         """
+        # TODO: 3pid registration can actually happen on the workers. Consider
+        # refactoring it.
         if self.hs.config.worker_app:
             await self._post_registration_client(
                 user_id=user_id, auth_result=auth_result, access_token=access_token
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 930047e730..07b2187eb1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
 
 from synapse.api.constants import (
     EventTypes,
+    HistoryVisibility,
     JoinRules,
     Membership,
     RoomCreationPreset,
@@ -37,7 +38,6 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
-from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
@@ -54,6 +54,7 @@ from synapse.types import (
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_and_validate_server_name
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -81,21 +82,21 @@ class RoomCreationHandler(BaseHandler):
         self._presets_dict = {
             RoomCreationPreset.PRIVATE_CHAT: {
                 "join_rules": JoinRules.INVITE,
-                "history_visibility": "shared",
+                "history_visibility": HistoryVisibility.SHARED,
                 "original_invitees_have_ops": False,
                 "guest_can_join": True,
                 "power_level_content_override": {"invite": 0},
             },
             RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
                 "join_rules": JoinRules.INVITE,
-                "history_visibility": "shared",
+                "history_visibility": HistoryVisibility.SHARED,
                 "original_invitees_have_ops": True,
                 "guest_can_join": True,
                 "power_level_content_override": {"invite": 0},
             },
             RoomCreationPreset.PUBLIC_CHAT: {
                 "join_rules": JoinRules.PUBLIC,
-                "history_visibility": "shared",
+                "history_visibility": HistoryVisibility.SHARED,
                 "original_invitees_have_ops": False,
                 "guest_can_join": False,
                 "power_level_content_override": {},
@@ -125,6 +126,10 @@ class RoomCreationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        self._invite_burst_count = (
+            hs.config.ratelimiting.rc_invites_per_room.burst_count
+        )
+
     async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ) -> str:
@@ -358,13 +363,13 @@ class RoomCreationHandler(BaseHandler):
         """
         user_id = requester.user.to_string()
 
-        if not self.spam_checker.user_may_create_room(user_id):
+        if not await self.spam_checker.user_may_create_room(user_id):
             raise SynapseError(403, "You are not permitted to create rooms")
 
         creation_content = {
             "room_version": new_room_version.identifier,
             "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
-        }
+        }  # type: JsonDict
 
         # Check if old room was non-federatable
 
@@ -440,6 +445,7 @@ class RoomCreationHandler(BaseHandler):
             invite_list=[],
             initial_state=initial_state,
             creation_content=creation_content,
+            ratelimit=False,
         )
 
         # Transfer membership events
@@ -608,7 +614,7 @@ class RoomCreationHandler(BaseHandler):
                 403, "You are not permitted to create rooms", Codes.FORBIDDEN
             )
 
-        if not is_requester_admin and not self.spam_checker.user_may_create_room(
+        if not is_requester_admin and not await self.spam_checker.user_may_create_room(
             user_id
         ):
             raise SynapseError(403, "You are not permitted to create rooms")
@@ -660,6 +666,9 @@ class RoomCreationHandler(BaseHandler):
             invite_3pid_list = []
             invite_list = []
 
+        if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
+            raise SynapseError(400, "Cannot invite so many users at once")
+
         await self.event_creation_handler.assert_accepted_privacy_policy(requester)
 
         power_level_content_override = config.get("power_level_content_override")
@@ -735,6 +744,7 @@ class RoomCreationHandler(BaseHandler):
             room_alias=room_alias,
             power_level_content_override=power_level_content_override,
             creator_join_profile=creator_join_profile,
+            ratelimit=ratelimit,
         )
 
         if "name" in config:
@@ -838,6 +848,7 @@ class RoomCreationHandler(BaseHandler):
         room_alias: Optional[RoomAlias] = None,
         power_level_content_override: Optional[JsonDict] = None,
         creator_join_profile: Optional[JsonDict] = None,
+        ratelimit: bool = True,
     ) -> int:
         """Sends the initial events into a new room.
 
@@ -884,7 +895,7 @@ class RoomCreationHandler(BaseHandler):
             creator.user,
             room_id,
             "join",
-            ratelimit=False,
+            ratelimit=ratelimit,
             content=creator_join_profile,
         )
 
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 4a13c8e912..14f14db449 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,19 +15,22 @@
 
 import logging
 from collections import namedtuple
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import msgpack
 from unpaddedbase64 import decode_base64, encode_base64
 
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
 from synapse.api.errors import Codes, HttpResponseException
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.response_cache import ResponseCache
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
@@ -37,37 +40,38 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
 
 
 class RoomListHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search
-        self.response_cache = ResponseCache(hs, "room_list")
+        self.response_cache = ResponseCache(
+            hs, "room_list"
+        )  # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
         self.remote_response_cache = ResponseCache(
             hs, "remote_room_list", timeout_ms=30 * 1000
-        )
+        )  # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
 
     async def get_local_public_room_list(
         self,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        network_tuple=EMPTY_THIRD_PARTY_ID,
-        from_federation=False,
-    ):
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[dict] = None,
+        network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+        from_federation: bool = False,
+    ) -> JsonDict:
         """Generate a local public room list.
 
         There are multiple different lists: the main one plus one per third
         party network. A client can ask for a specific list or to return all.
 
         Args:
-            limit (int|None)
-            since_token (str|None)
-            search_filter (dict|None)
-            network_tuple (ThirdPartyInstanceID): Which public list to use.
+            limit
+            since_token
+            search_filter
+            network_tuple: Which public list to use.
                 This can be (None, None) to indicate the main list, or a particular
                 appservice and network id to use an appservice specific one.
                 Setting to None returns all public rooms across all lists.
-            from_federation (bool): true iff the request comes from the federation
-                API
+            from_federation: true iff the request comes from the federation API
         """
         if not self.enable_room_list_search:
             return {"chunk": [], "total_room_count_estimate": 0}
@@ -107,10 +111,10 @@ class RoomListHandler(BaseHandler):
         self,
         limit: Optional[int] = None,
         since_token: Optional[str] = None,
-        search_filter: Optional[Dict] = None,
+        search_filter: Optional[dict] = None,
         network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
         from_federation: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> JsonDict:
         """Generate a public room list.
         Args:
             limit: Maximum amount of rooms to return.
@@ -131,13 +135,17 @@ class RoomListHandler(BaseHandler):
         if since_token:
             batch_token = RoomListNextBatch.from_token(since_token)
 
-            bounds = (batch_token.last_joined_members, batch_token.last_room_id)
+            bounds = (
+                batch_token.last_joined_members,
+                batch_token.last_room_id,
+            )  # type: Optional[Tuple[int, str]]
             forwards = batch_token.direction_is_forward
+            has_batch_token = True
         else:
-            batch_token = None
             bounds = None
 
             forwards = True
+            has_batch_token = False
 
         # we request one more than wanted to see if there are more pages to come
         probing_limit = limit + 1 if limit is not None else None
@@ -159,7 +167,8 @@ class RoomListHandler(BaseHandler):
                 "canonical_alias": room["canonical_alias"],
                 "num_joined_members": room["joined_members"],
                 "avatar_url": room["avatar"],
-                "world_readable": room["history_visibility"] == "world_readable",
+                "world_readable": room["history_visibility"]
+                == HistoryVisibility.WORLD_READABLE,
                 "guest_can_join": room["guest_access"] == "can_join",
             }
 
@@ -168,7 +177,7 @@ class RoomListHandler(BaseHandler):
 
         results = [build_room_entry(r) for r in results]
 
-        response = {}
+        response = {}  # type: JsonDict
         num_results = len(results)
         if limit is not None:
             more_to_come = num_results == probing_limit
@@ -186,7 +195,7 @@ class RoomListHandler(BaseHandler):
             initial_entry = results[0]
 
             if forwards:
-                if batch_token:
+                if has_batch_token:
                     # If there was a token given then we assume that there
                     # must be previous results.
                     response["prev_batch"] = RoomListNextBatch(
@@ -202,7 +211,7 @@ class RoomListHandler(BaseHandler):
                         direction_is_forward=True,
                     ).to_token()
             else:
-                if batch_token:
+                if has_batch_token:
                     response["next_batch"] = RoomListNextBatch(
                         last_joined_members=final_entry["num_joined_members"],
                         last_room_id=final_entry["room_id"],
@@ -292,7 +301,7 @@ class RoomListHandler(BaseHandler):
                 return None
 
         # Return whether this room is open to federation users or not
-        create_event = current_state.get((EventTypes.Create, ""))
+        create_event = current_state[EventTypes.Create, ""]
         result["m.federate"] = create_event.content.get("m.federate", True)
 
         name_event = current_state.get((EventTypes.Name, ""))
@@ -317,7 +326,7 @@ class RoomListHandler(BaseHandler):
         visibility = None
         if visibility_event:
             visibility = visibility_event.content.get("history_visibility", None)
-        result["world_readable"] = visibility == "world_readable"
+        result["world_readable"] = visibility == HistoryVisibility.WORLD_READABLE
 
         guest_event = current_state.get((EventTypes.GuestAccess, ""))
         guest = None
@@ -335,13 +344,13 @@ class RoomListHandler(BaseHandler):
 
     async def get_remote_public_room_list(
         self,
-        server_name,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        include_all_networks=False,
-        third_party_instance_id=None,
-    ):
+        server_name: str,
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[dict] = None,
+        include_all_networks: bool = False,
+        third_party_instance_id: Optional[str] = None,
+    ) -> JsonDict:
         if not self.enable_room_list_search:
             return {"chunk": [], "total_room_count_estimate": 0}
 
@@ -398,13 +407,13 @@ class RoomListHandler(BaseHandler):
 
     async def _get_remote_list_cached(
         self,
-        server_name,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        include_all_networks=False,
-        third_party_instance_id=None,
-    ):
+        server_name: str,
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[dict] = None,
+        include_all_networks: bool = False,
+        third_party_instance_id: Optional[str] = None,
+    ) -> JsonDict:
         repl_layer = self.hs.get_federation_client()
         if search_filter:
             # We can't cache when asking for search
@@ -455,24 +464,24 @@ class RoomListNextBatch(
     REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
 
     @classmethod
-    def from_token(cls, token):
+    def from_token(cls, token: str) -> "RoomListNextBatch":
         decoded = msgpack.loads(decode_base64(token), raw=False)
         return RoomListNextBatch(
             **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
         )
 
-    def to_token(self):
+    def to_token(self) -> str:
         return encode_base64(
             msgpack.dumps(
                 {self.KEY_DICT[key]: val for key, val in self._asdict().items()}
             )
         )
 
-    def copy_and_replace(self, **kwds):
+    def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
         return self._replace(**kwds)
 
 
-def _matches_room_entry(room_entry, search_filter):
+def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
     if search_filter and search_filter.get("generic_search_term", None):
         generic_search_term = search_filter["generic_search_term"].upper()
         if generic_search_term in room_entry.get("name", "").upper():
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index c002886324..a5da97cfe0 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.registration_handler = hs.get_registration_handler()
         self.profile_handler = hs.get_profile_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self.account_data_handler = hs.get_account_data_handler()
 
         self.member_linearizer = Linearizer(name="member")
 
@@ -84,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
 
+        self._invites_per_room_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
+        )
+        self._invites_per_user_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
+        )
+
         # This is only used to get at ratelimit function, and
         # maybe_kick_guest_users. It's fine there are multiple of these as
         # it doesn't store state.
@@ -143,6 +155,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()
 
+    def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
+        """Ratelimit invites by room and by target user.
+
+        If room ID is missing then we just rate limit by target user.
+        """
+        if room_id:
+            self._invites_per_room_limiter.ratelimit(room_id)
+
+        self._invites_per_user_limiter.ratelimit(invitee_user_id)
+
     async def _local_membership_update(
         self,
         requester: Requester,
@@ -203,7 +225,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
             # Only rate-limit if the user actually joined the room, otherwise we'll end
             # up blocking profile updates.
-            if newly_joined:
+            if newly_joined and ratelimit:
                 time_now_s = self.clock.time()
                 (
                     allowed,
@@ -253,7 +275,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     direct_rooms[key].append(new_room_id)
 
                     # Save back to user's m.direct account data
-                    await self.store.add_account_data_for_user(
+                    await self.account_data_handler.add_account_data_for_user(
                         user_id, AccountDataTypes.DIRECT, direct_rooms
                     )
                     break
@@ -263,7 +285,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Copy each room tag to the new room
         for tag, tag_content in room_tags.items():
-            await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+            await self.account_data_handler.add_tag_to_room(
+                user_id, new_room_id, tag, tag_content
+            )
 
     async def update_membership(
         self,
@@ -384,8 +408,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 raise SynapseError(403, "This room has been blocked on this server")
 
         if effective_membership_state == Membership.INVITE:
+            target_id = target.to_string()
+            if ratelimit:
+                # Don't ratelimit application services.
+                if not requester.app_service or requester.app_service.is_rate_limited():
+                    self.ratelimit_invite(room_id, target_id)
+
             # block any attempts to invite the server notices mxid
-            if target.to_string() == self._server_notices_mxid:
+            if target_id == self._server_notices_mxid:
                 raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
             block_invite = False
@@ -408,8 +438,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     )
                     block_invite = True
 
-                if not self.spam_checker.user_may_invite(
-                    requester.user.to_string(), target.to_string(), room_id
+                if not await self.spam_checker.user_may_invite(
+                    requester.user.to_string(), target_id, room_id
                 ):
                     logger.info("Blocking invite due to spam checker")
                     block_invite = True
@@ -488,17 +518,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     raise AuthError(403, "Guest access not allowed")
 
             if not is_host_in_room:
-                time_now_s = self.clock.time()
-                (
-                    allowed,
-                    time_allowed,
-                ) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
-
-                if not allowed:
-                    raise LimitExceededError(
-                        retry_after_ms=int(1000 * (time_allowed - time_now_s))
+                if ratelimit:
+                    time_now_s = self.clock.time()
+                    (
+                        allowed,
+                        time_allowed,
+                    ) = self._join_rate_limiter_remote.can_requester_do_action(
+                        requester,
                     )
 
+                    if not allowed:
+                        raise LimitExceededError(
+                            retry_after_ms=int(1000 * (time_allowed - time_now_s))
+                        )
+
                 inviter = await self._get_inviter(target.to_string(), room_id)
                 if inviter and not self.hs.is_mine(inviter):
                     remote_room_hosts.append(inviter.domain)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..e88fd59749 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -34,7 +34,6 @@ from synapse.types import (
     map_username_to_mxid_localpart,
     mxid_localpart_allowed_characters,
 )
-from synapse.util.async_helpers import Linearizer
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
@@ -59,8 +58,6 @@ class SamlHandler(BaseHandler):
         super().__init__(hs)
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._saml_idp_entityid = hs.config.saml2_idp_entityid
-        self._auth_handler = hs.get_auth_handler()
-        self._registration_handler = hs.get_registration_handler()
 
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
         self._grandfathered_mxid_source_attribute = (
@@ -76,30 +73,46 @@ class SamlHandler(BaseHandler):
         )
 
         # identifier for the external_ids table
-        self._auth_provider_id = "saml"
+        self.idp_id = "saml"
+
+        # user-facing name of this auth provider
+        self.idp_name = "SAML"
+
+        # we do not currently support icons/brands for SAML auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
+        self.idp_brand = None
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
-        # a lock on the mappings
-        self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
         self._sso_handler = hs.get_sso_handler()
+        self._sso_handler.register_identity_provider(self)
 
-    def handle_redirect_request(
-        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
-    ) -> bytes:
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
+            request: the incoming HTTP request
             client_redirect_url: the URL that we should redirect the
-                client to when everything is done
+                client to after login (or None for UI Auth).
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
             URL to redirect to
         """
+        if not client_redirect_url:
+            # Some SAML identity providers (e.g. Google) require a
+            # RelayState parameter on requests, so pass in a dummy redirect URL
+            # (which will never get used).
+            client_redirect_url = b"unused"
+
         reqid, info = self._saml_client.prepare_for_authenticate(
             entityid=self._saml_idp_entityid, relay_state=client_redirect_url
         )
@@ -120,7 +133,7 @@ class SamlHandler(BaseHandler):
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
     async def handle_saml_response(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_matrix/saml2/authn_response
+        """Handle an incoming request to /_synapse/client/saml2/authn_response
 
         Args:
             request: the incoming request from the browser. We'll
@@ -167,6 +180,29 @@ class SamlHandler(BaseHandler):
             return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+        await self._handle_authn_response(request, saml2_auth, relay_state)
+
+    async def _handle_authn_response(
+        self,
+        request: SynapseRequest,
+        saml2_auth: saml2.response.AuthnResponse,
+        relay_state: str,
+    ) -> None:
+        """Handle an AuthnResponse, having parsed it from the request params
+
+        Assumes that the signature on the response object has been checked. Maps
+        the user onto an MXID, registering them if necessary, and returns a response
+        to the browser.
+
+        Args:
+            request: the incoming request from the browser. We'll respond to it with an
+                HTML page or a redirect
+            saml2_auth: the parsed AuthnResponse object
+            relay_state: the RelayState query param, which encodes the URI to rediret
+               back to
+        """
+
         for assertion in saml2_auth.assertions:
             # kibana limits the length of a log field, whereas this is all rather
             # useful, so split it up.
@@ -183,6 +219,24 @@ class SamlHandler(BaseHandler):
             saml2_auth.in_response_to, None
         )
 
+        # first check if we're doing a UIA
+        if current_session and current_session.ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+            except MappingException as e:
+                logger.exception("Failed to extract remote user id from SAML response")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self.idp_id,
+                remote_user_id,
+                current_session.ui_auth_session_id,
+                request,
+            )
+
+        # otherwise, we're handling a login request.
+
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
         for requirement in self._saml2_attribute_requirements:
@@ -192,63 +246,39 @@ class SamlHandler(BaseHandler):
                 )
                 return
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
-
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_saml_response_to_user(
-                saml2_auth, relay_state, user_agent, ip_address
-            )
+            await self._complete_saml_login(saml2_auth, request, relay_state)
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
-            return
-
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
 
-    async def _map_saml_response_to_user(
+    async def _complete_saml_login(
         self,
         saml2_auth: saml2.response.AuthnResponse,
+        request: SynapseRequest,
         client_redirect_url: str,
-        user_agent: str,
-        ip_address: str,
-    ) -> str:
+    ) -> None:
         """
-        Given a SAML response, retrieve the user ID for it and possibly register the user.
+        Given a SAML response, complete the login flow
+
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
 
         Args:
             saml2_auth: The parsed SAML2 response.
+            request: The request to respond to
             client_redirect_url: The redirect URL passed in by the client.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Returns:
-             The user ID associated with this response.
 
         Raises:
             MappingException if there was a problem mapping the response to a user.
             RedirectException: some mapping providers may raise this if they need
                 to redirect to an interstitial page.
         """
-
-        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+        remote_user_id = self._remote_id_from_saml_response(
             saml2_auth, client_redirect_url
         )
 
-        if not remote_user_id:
-            raise MappingException(
-                "Failed to extract remote user id from SAML response"
-            )
-
         async def saml_response_to_remapped_user_attributes(
             failures: int,
         ) -> UserAttributes:
@@ -294,16 +324,44 @@ class SamlHandler(BaseHandler):
 
             return None
 
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
-            return await self._sso_handler.get_mxid_from_sso(
-                self._auth_provider_id,
-                remote_user_id,
-                user_agent,
-                ip_address,
-                saml_response_to_remapped_user_attributes,
-                grandfather_existing_users,
+        await self._sso_handler.complete_sso_login_request(
+            self.idp_id,
+            remote_user_id,
+            request,
+            client_redirect_url,
+            saml_response_to_remapped_user_attributes,
+            grandfather_existing_users,
+        )
+
+    def _remote_id_from_saml_response(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: Optional[str],
+    ) -> str:
+        """Extract the unique remote id from a SAML2 AuthnResponse
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+        Returns:
+            remote user id
+
+        Raises:
+            MappingException if there was an error extracting the user id
+        """
+        # It's not obvious why we need to pass in the redirect URI to the mapping
+        # provider, but we do :/
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+            saml2_auth, client_redirect_url
+        )
+
+        if not remote_user_id:
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
             )
 
+        return remote_user_id
+
     def expire_sessions(self):
         expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..94062e79cb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@
 
 import itertools
 import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
 
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
+from synapse.events import EventBase
 from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SearchHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
 
         return historical_room_ids
 
-    async def search(self, user, content, batch=None):
+    async def search(
+        self, user: UserID, content: JsonDict, batch: Optional[str] = None
+    ) -> JsonDict:
         """Performs a full text search for a user.
 
         Args:
-            user (UserID)
-            content (dict): Search parameters
-            batch (str): The next_batch parameter. Used for pagination.
+            user
+            content: Search parameters
+            batch: The next_batch parameter. Used for pagination.
 
         Returns:
             dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
         if search_filter.rooms:
-            historical_room_ids = []
+            historical_room_ids = []  # type: List[str]
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
                 ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
 
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
-        room_groups = {}  # Holds result of grouping by room, if applicable
-        sender_group = {}  # Holds result of grouping by sender, if applicable
+        # Holds result of grouping by room, if applicable
+        room_groups = {}  # type: Dict[str, JsonDict]
+        # Holds result of grouping by sender, if applicable
+        sender_group = {}  # type: Dict[str, JsonDict]
 
         # Holds the next_batch for the entire result set if one of those exists
         global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
                 s["results"].append(e.event_id)
 
         elif order_by == "recent":
-            room_events = []
+            room_events = []  # type: List[EventBase]
             i = 0
 
             pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
 
         state_results = {}
         if include_state:
-            rooms = {e.room_id for e in allowed_events}
-            for room_id in rooms:
+            for room_id in {e.room_id for e in allowed_events}:
                 state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
-            state_results.values()
-
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise the 'age' will be wrong
 
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
 
         if state_results:
             s = {}
-            for room_id, state in state_results.items():
+            for room_id, state_events in state_results.items():
                 s[room_id] = await self._event_serializer.serialize_events(
-                    state, time_now
+                    state_events, time_now
                 )
 
             rooms_cat_res["state"] = s
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..84af2dde7e 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -13,24 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.types import Requester
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
-        self._password_policy_handler = hs.get_password_policy_handler()
 
     async def set_password(
         self,
@@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
         password_hash: str,
         logout_devices: bool,
         requester: Optional[Requester] = None,
-    ):
+    ) -> None:
         if not self.hs.config.password_localdb_enabled:
             raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..b450668f1c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,15 +12,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Mapping,
+    Optional,
+    Set,
+)
+from urllib.parse import urlencode
 
 import attr
+from typing_extensions import NoReturn, Protocol
 
-from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
-from synapse.http.server import respond_with_html
-from synapse.types import UserID, contains_invalid_mxid_characters
+from twisted.web.http import Request
+from twisted.web.iweb import IRequest
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+from synapse.http import get_request_user_agent
+from synapse.http.server import respond_with_html, respond_with_redirect
+from synapse.http.site import SynapseRequest
+from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
+from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -35,24 +55,190 @@ class MappingException(Exception):
     """
 
 
+class SsoIdentityProvider(Protocol):
+    """Abstract base class to be implemented by SSO Identity Providers
+
+    An Identity Provider, or IdP, is an external HTTP service which authenticates a user
+    to say whether they should be allowed to log in, or perform a given action.
+
+    Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
+    and CAS.
+
+    The main entry point is `handle_redirect_request`, which should return a URI to
+    redirect the user's browser to the IdP's authentication page.
+
+    Each IdP should be registered with the SsoHandler via
+    `hs.get_sso_handler().register_identity_provider()`, so that requests to
+    `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
+    """
+
+    @property
+    @abc.abstractmethod
+    def idp_id(self) -> str:
+        """A unique identifier for this SSO provider
+
+        Eg, "saml", "cas", "github"
+        """
+
+    @property
+    @abc.abstractmethod
+    def idp_name(self) -> str:
+        """User-facing name for this provider"""
+
+    @property
+    def idp_icon(self) -> Optional[str]:
+        """Optional MXC URI for user-facing icon"""
+        return None
+
+    @property
+    def idp_brand(self) -> Optional[str]:
+        """Optional branding identifier"""
+        return None
+
+    @abc.abstractmethod
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Handle an incoming request to /login/sso/redirect
+
+        Args:
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            URL to redirect to
+        """
+        raise NotImplementedError()
+
+
 @attr.s
 class UserAttributes:
-    localpart = attr.ib(type=str)
+    # the localpart of the mxid that the mapper has assigned to the user.
+    # if `None`, the mapper has not picked a userid, and the user should be prompted to
+    # enter one.
+    localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
-    emails = attr.ib(type=List[str], default=attr.Factory(list))
+    emails = attr.ib(type=Collection[str], default=attr.Factory(list))
+
+
+@attr.s(slots=True)
+class UsernameMappingSession:
+    """Data we track about SSO sessions"""
+
+    # A unique identifier for this SSO provider, e.g.  "oidc" or "saml".
+    auth_provider_id = attr.ib(type=str)
+
+    # user ID on the IdP server
+    remote_user_id = attr.ib(type=str)
 
+    # attributes returned by the ID mapper
+    display_name = attr.ib(type=Optional[str])
+    emails = attr.ib(type=Collection[str])
 
-class SsoHandler(BaseHandler):
+    # An optional dictionary of extra attributes to be provided to the client in the
+    # login response.
+    extra_login_attributes = attr.ib(type=Optional[JsonDict])
+
+    # where to redirect the client back to
+    client_redirect_url = attr.ib(type=str)
+
+    # expiry time for the session, in milliseconds
+    expiry_time_ms = attr.ib(type=int)
+
+    # choices made by the user
+    chosen_localpart = attr.ib(type=Optional[str], default=None)
+    use_display_name = attr.ib(type=bool, default=True)
+    emails_to_use = attr.ib(type=Collection[str], default=())
+    terms_accepted_version = attr.ib(type=Optional[str], default=None)
+
+
+# the HTTP cookie used to track the mapping session id
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
+
+
+class SsoHandler:
     # The number of attempts to ask the mapping provider for when generating an MXID.
     _MAP_USERNAME_RETRIES = 1000
 
+    # the time a UsernameMappingSession remains valid for
+    _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
+
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        self._clock = hs.get_clock()
+        self._store = hs.get_datastore()
+        self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
+        self._auth_handler = hs.get_auth_handler()
         self._error_template = hs.config.sso_error_template
+        self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+        # The following template is shown after a successful user interactive
+        # authentication session. It tells the user they can close the window.
+        self._sso_auth_success_template = hs.config.sso_auth_success_template
+
+        # a lock on the mappings
+        self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
+
+        # a map from session id to session data
+        self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
+
+        # map from idp_id to SsoIdentityProvider
+        self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
+
+        self._consent_at_registration = hs.config.consent.user_consent_at_registration
+
+    def register_identity_provider(self, p: SsoIdentityProvider):
+        p_id = p.idp_id
+        assert p_id not in self._identity_providers
+        self._identity_providers[p_id] = p
+
+    def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
+        """Get the configured identity providers"""
+        return self._identity_providers
+
+    async def get_identity_providers_for_user(
+        self, user_id: str
+    ) -> Mapping[str, SsoIdentityProvider]:
+        """Get the SsoIdentityProviders which a user has used
+
+        Given a user id, get the identity providers that that user has used to log in
+        with in the past (and thus could use to re-identify themselves for UI Auth).
+
+        Args:
+            user_id: MXID of user to look up
+
+        Raises:
+            a map of idp_id to SsoIdentityProvider
+        """
+        external_ids = await self._store.get_external_ids_by_user(user_id)
+
+        valid_idps = {}
+        for idp_id, _ in external_ids:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                logger.warning(
+                    "User %r has an SSO mapping for IdP %r, but this is no longer "
+                    "configured.",
+                    user_id,
+                    idp_id,
+                )
+            else:
+                valid_idps[idp_id] = idp
+
+        return valid_idps
 
     def render_error(
-        self, request, error: str, error_description: Optional[str] = None
+        self,
+        request: Request,
+        error: str,
+        error_description: Optional[str] = None,
+        code: int = 400,
     ) -> None:
         """Renders the error template and responds with it.
 
@@ -64,11 +250,53 @@ class SsoHandler(BaseHandler):
                 We'll respond with an HTML page describing the error.
             error: A technical identifier for this error.
             error_description: A human-readable description of the error.
+            code: The integer error code (an HTTP response code)
         """
         html = self._error_template.render(
             error=error, error_description=error_description
         )
-        respond_with_html(request, 400, html)
+        respond_with_html(request, code, html)
+
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        idp_id: Optional[str],
+    ) -> str:
+        """Handle a request to /login/sso/redirect
+
+        Args:
+            request: incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login.
+            idp_id: optional identity provider chosen by the client
+
+        Returns:
+             the URI to redirect to
+        """
+        if not self._identity_providers:
+            raise SynapseError(
+                400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
+            )
+
+        # if the client chose an IdP, use that
+        idp = None  # type: Optional[SsoIdentityProvider]
+        if idp_id:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                raise NotFoundError("Unknown identity provider")
+
+        # if we only have one auth provider, redirect to it directly
+        elif len(self._identity_providers) == 1:
+            idp = next(iter(self._identity_providers.values()))
+
+        if idp:
+            return await idp.handle_redirect_request(request, client_redirect_url)
+
+        # otherwise, redirect to the IDP picker
+        return "/_synapse/client/pick_idp?" + urlencode(
+            (("redirectUrl", client_redirect_url),)
+        )
 
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str
@@ -95,7 +323,7 @@ class SsoHandler(BaseHandler):
         )
 
         # Check if we already have a mapping for this user.
-        previously_registered_user_id = await self.store.get_user_by_external_id(
+        previously_registered_user_id = await self._store.get_user_by_external_id(
             auth_provider_id, remote_user_id,
         )
 
@@ -112,15 +340,16 @@ class SsoHandler(BaseHandler):
         # No match.
         return None
 
-    async def get_mxid_from_sso(
+    async def complete_sso_login_request(
         self,
         auth_provider_id: str,
         remote_user_id: str,
-        user_agent: str,
-        ip_address: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
-        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
-    ) -> str:
+        grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
+        extra_login_attributes: Optional[JsonDict] = None,
+    ) -> None:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.
 
@@ -139,12 +368,18 @@ class SsoHandler(BaseHandler):
         given user-agent and IP address and the SSO ID is linked to this matrix
         ID for subsequent calls.
 
+        Finally, we generate a redirect to the supplied redirect uri, with a login token
+
         Args:
             auth_provider_id: A unique identifier for this SSO provider, e.g.
                 "oidc" or "saml".
+
             remote_user_id: The unique identifier from the SSO provider.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+
+            request: The request to respond to
+
+            client_redirect_url: The redirect URL passed in by the client.
+
             sso_to_matrix_id_mapper: A callable to generate the user attributes.
                 The only parameter is an integer which represents the amount of
                 times the returned mxid localpart mapping has failed.
@@ -156,12 +391,13 @@ class SsoHandler(BaseHandler):
                         to the user.
                     RedirectException to redirect to an additional page (e.g.
                         to prompt the user for more information).
+
             grandfather_existing_users: A callable which can return an previously
                 existing matrix ID. The SSO ID is then linked to the returned
                 matrix ID.
 
-        Returns:
-             The user ID associated with the SSO response.
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
 
         Raises:
             MappingException if there was a problem mapping the response to a user.
@@ -169,24 +405,62 @@ class SsoHandler(BaseHandler):
                 to an additional page. (e.g. to prompt for more information)
 
         """
-        # first of all, check if we already have a mapping for this user
-        previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
-            auth_provider_id, remote_user_id,
-        )
-        if previously_registered_user_id:
-            return previously_registered_user_id
+        new_user = False
+
+        # grab a lock while we try to find a mapping for this user. This seems...
+        # optimistic, especially for implementations that end up redirecting to
+        # interstitial pages.
+        with await self._mapping_lock.queue(auth_provider_id):
+            # first of all, check if we already have a mapping for this user
+            user_id = await self.get_sso_user_by_remote_user_id(
+                auth_provider_id, remote_user_id,
+            )
 
-        # Check for grandfathering of users.
-        if grandfather_existing_users:
-            previously_registered_user_id = await grandfather_existing_users()
-            if previously_registered_user_id:
-                # Future logins should also match this user ID.
-                await self.store.record_user_external_id(
-                    auth_provider_id, remote_user_id, previously_registered_user_id
+            # Check for grandfathering of users.
+            if not user_id:
+                user_id = await grandfather_existing_users()
+                if user_id:
+                    # Future logins should also match this user ID.
+                    await self._store.record_user_external_id(
+                        auth_provider_id, remote_user_id, user_id
+                    )
+
+            # Otherwise, generate a new user.
+            if not user_id:
+                attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+
+                if attributes.localpart is None:
+                    # the mapper doesn't return a username. bail out with a redirect to
+                    # the username picker.
+                    await self._redirect_to_username_picker(
+                        auth_provider_id,
+                        remote_user_id,
+                        attributes,
+                        client_redirect_url,
+                        extra_login_attributes,
+                    )
+
+                user_id = await self._register_mapped_user(
+                    attributes,
+                    auth_provider_id,
+                    remote_user_id,
+                    get_request_user_agent(request),
+                    request.getClientIP(),
                 )
-                return previously_registered_user_id
+                new_user = True
+
+        await self._auth_handler.complete_sso_login(
+            user_id,
+            request,
+            client_redirect_url,
+            extra_login_attributes,
+            new_user=new_user,
+        )
 
-        # Otherwise, generate a new user.
+    async def _call_attribute_mapper(
+        self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+    ) -> UserAttributes:
+        """Call the attribute mapper function in a loop, until we get a unique userid"""
         for i in range(self._MAP_USERNAME_RETRIES):
             try:
                 attributes = await sso_to_matrix_id_mapper(i)
@@ -208,14 +482,12 @@ class SsoHandler(BaseHandler):
             )
 
             if not attributes.localpart:
-                raise MappingException(
-                    "Error parsing SSO response: SSO mapping provider plugin "
-                    "did not return a localpart value"
-                )
+                # the mapper has not picked a localpart
+                return attributes
 
             # Check if this mxid already exists
-            user_id = UserID(attributes.localpart, self.server_name).to_string()
-            if not await self.store.get_users_by_id_case_insensitive(user_id):
+            user_id = UserID(attributes.localpart, self._server_name).to_string()
+            if not await self._store.get_users_by_id_case_insensitive(user_id):
                 # This mxid is free
                 break
         else:
@@ -224,10 +496,101 @@ class SsoHandler(BaseHandler):
             raise MappingException(
                 "Unable to generate a Matrix ID from the SSO response"
             )
+        return attributes
+
+    async def _redirect_to_username_picker(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        attributes: UserAttributes,
+        client_redirect_url: str,
+        extra_login_attributes: Optional[JsonDict],
+    ) -> NoReturn:
+        """Creates a UsernameMappingSession and redirects the browser
+
+        Called if the user mapping provider doesn't return a localpart for a new user.
+        Raises a RedirectException which redirects the browser to the username picker.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            attributes: the user attributes returned by the user mapping provider.
+
+            client_redirect_url: The redirect URL passed in by the client, which we
+                will eventually redirect back to.
+
+            extra_login_attributes: An optional dictionary of extra
+                attributes to be provided to the client in the login response.
+
+        Raises:
+            RedirectException
+        """
+        session_id = random_string(16)
+        now = self._clock.time_msec()
+        session = UsernameMappingSession(
+            auth_provider_id=auth_provider_id,
+            remote_user_id=remote_user_id,
+            display_name=attributes.display_name,
+            emails=attributes.emails,
+            client_redirect_url=client_redirect_url,
+            expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
+            extra_login_attributes=extra_login_attributes,
+        )
+
+        self._username_mapping_sessions[session_id] = session
+        logger.info("Recorded registration session id %s", session_id)
+
+        # Set the cookie and redirect to the username picker
+        e = RedirectException(b"/_synapse/client/pick_username/account_details")
+        e.cookies.append(
+            b"%s=%s; path=/"
+            % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
+        )
+        raise e
+
+    async def _register_mapped_user(
+        self,
+        attributes: UserAttributes,
+        auth_provider_id: str,
+        remote_user_id: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """Register a new SSO user.
+
+        This is called once we have successfully mapped the remote user id onto a local
+        user id, one way or another.
+
+        Args:
+             attributes: user attributes returned by the user mapping provider,
+                including a non-empty localpart.
+
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+
+            remote_user_id: The unique identifier from the SSO provider.
+
+            user_agent: The user-agent in the HTTP request (used for potential
+                shadow-banning.)
+
+            ip_address: The IP address of the requester (used for potential
+                shadow-banning.)
+
+        Raises:
+            a MappingException if the localpart is invalid.
+
+            a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
+            is already taken.
+        """
 
         # Since the localpart is provided via a potentially untrusted module,
         # ensure the MXID is valid before registering.
-        if contains_invalid_mxid_characters(attributes.localpart):
+        if not attributes.localpart or contains_invalid_mxid_characters(
+            attributes.localpart
+        ):
             raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
 
         logger.debug("Mapped SSO user to local part %s", attributes.localpart)
@@ -238,7 +601,292 @@ class SsoHandler(BaseHandler):
             user_agent_ips=[(user_agent, ip_address)],
         )
 
-        await self.store.record_user_external_id(
+        await self._store.record_user_external_id(
             auth_provider_id, remote_user_id, registered_user_id
         )
         return registered_user_id
+
+    async def complete_sso_ui_auth_request(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        ui_auth_session_id: str,
+        request: Request,
+    ) -> None:
+        """
+        Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+        Note that this requires that the user is mapped in the "user_external_ids"
+        table. This will be the case if they have ever logged in via SAML or OIDC in
+        recentish synapse versions, but may not be for older users.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            ui_auth_session_id: The ID of the user-interactive auth session.
+            request: The request to complete.
+        """
+
+        user_id = await self.get_sso_user_by_remote_user_id(
+            auth_provider_id, remote_user_id,
+        )
+
+        user_id_to_verify = await self._auth_handler.get_session_data(
+            ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
+        if not user_id:
+            logger.warning(
+                "Remote user %s/%s has not previously logged in here: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+            )
+        elif user_id != user_id_to_verify:
+            logger.warning(
+                "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+                user_id,
+            )
+        else:
+            # success!
+            # Mark the stage of the authentication as successful.
+            await self._store.mark_ui_auth_stage_complete(
+                ui_auth_session_id, LoginType.SSO, user_id
+            )
+
+            # Render the HTML confirmation page and return.
+            html = self._sso_auth_success_template
+            respond_with_html(request, 200, html)
+            return
+
+        # the user_id didn't match: mark the stage of the authentication as unsuccessful
+        await self._store.mark_ui_auth_stage_complete(
+            ui_auth_session_id, LoginType.SSO, ""
+        )
+
+        # render an error page.
+        html = self._bad_user_template.render(
+            server_name=self._server_name, user_id_to_verify=user_id_to_verify,
+        )
+        respond_with_html(request, 200, html)
+
+    def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+        """Look up the given username mapping session
+
+        If it is not found, raises a SynapseError with an http code of 400
+
+        Args:
+            session_id: session to look up
+        Returns:
+            active mapping session
+        Raises:
+            SynapseError if the session is not found/has expired
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if session:
+            return session
+        logger.info("Couldn't find session id %s", session_id)
+        raise SynapseError(400, "unknown session")
+
+    async def check_username_availability(
+        self, localpart: str, session_id: str,
+    ) -> bool:
+        """Handle an "is username available" callback check
+
+        Args:
+            localpart: desired localpart
+            session_id: the session id for the username picker
+        Returns:
+            True if the username is available
+        Raises:
+            SynapseError if the localpart is invalid or the session is unknown
+        """
+
+        # make sure that there is a valid mapping session, to stop people dictionary-
+        # scanning for accounts
+        self.get_mapping_session(session_id)
+
+        logger.info(
+            "[session %s] Checking for availability of username %s",
+            session_id,
+            localpart,
+        )
+
+        if contains_invalid_mxid_characters(localpart):
+            raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
+        user_id = UserID(localpart, self._server_name).to_string()
+        user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
+
+        logger.info("[session %s] users: %s", session_id, user_infos)
+        return not user_infos
+
+    async def handle_submit_username_request(
+        self,
+        request: SynapseRequest,
+        session_id: str,
+        localpart: str,
+        use_display_name: bool,
+        emails_to_use: Iterable[str],
+    ) -> None:
+        """Handle a request to the username-picker 'submit' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            localpart: localpart requested by the user
+            session_id: ID of the username mapping session, extracted from a cookie
+            use_display_name: whether the user wants to use the suggested display name
+            emails_to_use: emails that the user would like to use
+        """
+        session = self.get_mapping_session(session_id)
+
+        # update the session with the user's choices
+        session.chosen_localpart = localpart
+        session.use_display_name = use_display_name
+
+        emails_from_idp = set(session.emails)
+        filtered_emails = set()  # type: Set[str]
+
+        # we iterate through the list rather than just building a set conjunction, so
+        # that we can log attempts to use unknown addresses
+        for email in emails_to_use:
+            if email in emails_from_idp:
+                filtered_emails.add(email)
+            else:
+                logger.warning(
+                    "[session %s] ignoring user request to use unknown email address %r",
+                    session_id,
+                    email,
+                )
+        session.emails_to_use = filtered_emails
+
+        # we may now need to collect consent from the user, in which case, redirect
+        # to the consent-extraction-unit
+        if self._consent_at_registration:
+            redirect_url = b"/_synapse/client/new_user_consent"
+
+        # otherwise, redirect to the completion page
+        else:
+            redirect_url = b"/_synapse/client/sso_register"
+
+        respond_with_redirect(request, redirect_url)
+
+    async def handle_terms_accepted(
+        self, request: Request, session_id: str, terms_version: str
+    ):
+        """Handle a request to the new-user 'consent' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+            terms_version: the version of the terms which the user viewed and consented
+                to
+        """
+        logger.info(
+            "[session %s] User consented to terms version %s",
+            session_id,
+            terms_version,
+        )
+        session = self.get_mapping_session(session_id)
+        session.terms_accepted_version = terms_version
+
+        # we're done; now we can register the user
+        respond_with_redirect(request, b"/_synapse/client/sso_register")
+
+    async def register_sso_user(self, request: Request, session_id: str) -> None:
+        """Called once we have all the info we need to register a new user.
+
+        Does so and serves an HTTP response
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        session = self.get_mapping_session(session_id)
+
+        logger.info(
+            "[session %s] Registering localpart %s",
+            session_id,
+            session.chosen_localpart,
+        )
+
+        attributes = UserAttributes(
+            localpart=session.chosen_localpart, emails=session.emails_to_use,
+        )
+
+        if session.use_display_name:
+            attributes.display_name = session.display_name
+
+        # the following will raise a 400 error if the username has been taken in the
+        # meantime.
+        user_id = await self._register_mapped_user(
+            attributes,
+            session.auth_provider_id,
+            session.remote_user_id,
+            get_request_user_agent(request),
+            request.getClientIP(),
+        )
+
+        logger.info(
+            "[session %s] Registered userid %s with attributes %s",
+            session_id,
+            user_id,
+            attributes,
+        )
+
+        # delete the mapping session and the cookie
+        del self._username_mapping_sessions[session_id]
+
+        # delete the cookie
+        request.addCookie(
+            USERNAME_MAPPING_SESSION_COOKIE_NAME,
+            b"",
+            expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
+            path=b"/",
+        )
+
+        auth_result = {}
+        if session.terms_accepted_version:
+            # TODO: make this less awful.
+            auth_result[LoginType.TERMS] = True
+
+        await self._registration_handler.post_registration_actions(
+            user_id, auth_result, access_token=None
+        )
+
+        await self._auth_handler.complete_sso_login(
+            user_id,
+            request,
+            session.client_redirect_url,
+            session.extra_login_attributes,
+            new_user=True,
+        )
+
+    def _expire_old_sessions(self):
+        to_expire = []
+        now = int(self._clock.time_msec())
+
+        for session_id, session in self._username_mapping_sessions.items():
+            if session.expiry_time_ms <= now:
+                to_expire.append(session_id)
+
+        for session_id in to_expire:
+            logger.info("Expiring mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+    """Extract the session ID from the cookie
+
+    Raises a SynapseError if the cookie isn't found
+    """
+    session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+    if not session_id:
+        raise SynapseError(code=400, msg="missing session_id")
+    return session_id.decode("ascii", errors="replace")
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..b3f9875358 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
 class StateDeltasHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+    async def _get_key_change(
+        self,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+        key_name: str,
+        public_value: str,
+    ) -> Optional[bool]:
         """Given two events check if the `key_name` field in content changed
         from not matching `public_value` to doing so.
 
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index dc62b21c06..d261d7cd4e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -12,13 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +37,7 @@ class StatsHandler:
     Heavily derived from UserDirectoryHandler
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
@@ -44,7 +50,7 @@ class StatsHandler:
         self.stats_enabled = hs.config.stats_enabled
 
         # The current position in the current_state_delta stream
-        self.pos = None
+        self.pos = None  # type: Optional[int]
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -56,7 +62,7 @@ class StatsHandler:
             # we start populating stats
             self.clock.call_later(0, self.notify_new_event)
 
-    def notify_new_event(self):
+    def notify_new_event(self) -> None:
         """Called when there may be more deltas to process
         """
         if not self.stats_enabled or self._is_processing:
@@ -72,7 +78,7 @@ class StatsHandler:
 
         run_as_background_process("stats.notify_new_event", process)
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_stats_positions()
@@ -110,10 +116,10 @@ class StatsHandler:
             )
 
             for room_id, fields in room_count.items():
-                room_deltas.setdefault(room_id, {}).update(fields)
+                room_deltas.setdefault(room_id, Counter()).update(fields)
 
             for user_id, fields in user_count.items():
-                user_deltas.setdefault(user_id, {}).update(fields)
+                user_deltas.setdefault(user_id, Counter()).update(fields)
 
             logger.debug("room_deltas: %s", room_deltas)
             logger.debug("user_deltas: %s", user_deltas)
@@ -131,19 +137,20 @@ class StatsHandler:
 
             self.pos = max_pos
 
-    async def _handle_deltas(self, deltas):
+    async def _handle_deltas(
+        self, deltas: Iterable[JsonDict]
+    ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
         """Called with the state deltas to process
 
         Returns:
-            tuple[dict[str, Counter], dict[str, counter]]
             Two dicts: the room deltas and the user deltas,
             mapping from room/user ID to changes in the various fields.
         """
 
-        room_to_stats_deltas = {}
-        user_to_stats_deltas = {}
+        room_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
+        user_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
 
-        room_to_state_updates = {}
+        room_to_state_updates = {}  # type: Dict[str, Dict[str, Any]]
 
         for delta in deltas:
             typ = delta["type"]
@@ -173,7 +180,7 @@ class StatsHandler:
                 )
                 continue
 
-            event_content = {}
+            event_content = {}  # type: JsonDict
 
             sender = None
             if event_id is not None:
@@ -257,13 +264,13 @@ class StatsHandler:
                     )
 
                     if has_changed_joinedness:
-                        delta = +1 if membership == Membership.JOIN else -1
+                        membership_delta = +1 if membership == Membership.JOIN else -1
 
                         user_to_stats_deltas.setdefault(user_id, Counter())[
                             "joined_rooms"
-                        ] += delta
+                        ] += membership_delta
 
-                        room_stats_delta["local_users_in_room"] += delta
+                        room_stats_delta["local_users_in_room"] += membership_delta
 
             elif typ == EventTypes.Create:
                 room_state["is_federatable"] = (
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9827c7eb8d..5c7590f38e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -554,7 +554,7 @@ class SyncHandler:
             event.event_id, state_filter=state_filter
         )
         if event.is_state():
-            state_ids = state_ids.copy()
+            state_ids = dict(state_ids)
             state_ids[(event.type, event.state_key)] = event.event_id
         return state_ids
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..3f0dfc7a74 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
 import logging
 import random
 from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
             )
 
         # map room IDs to serial numbers
-        self._room_serials = {}
+        self._room_serials = {}  # type: Dict[str, int]
         # map room IDs to sets of users currently typing
-        self._room_typing = {}
+        self._room_typing = {}  # type: Dict[str, Set[str]]
 
-        self._member_last_federation_poke = {}
+        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
 
         self.clock.looping_call(self._handle_timeouts, 5000)
 
-    def _reset(self):
+    def _reset(self) -> None:
         """Reset the typing handler's data caches.
         """
         # map room IDs to serial numbers
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
-    def _handle_timeouts(self):
+    def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
         now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
         for member in members:
             self._handle_timeout_for_member(now, member)
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         if not self.is_typing(member):
             # Nothing to do if they're no longer typing
             return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
         # each person typing.
         self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
 
-    def is_typing(self, member):
+    def is_typing(self, member: RoomMember) -> bool:
         return member.user_id in self._room_typing.get(member.room_id, [])
 
-    async def _push_remote(self, member, typing):
+    async def _push_remote(self, member: RoomMember, typing: bool) -> None:
         if not self.federation:
             return
 
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         """Should be called whenever we receive updates for typing stream.
         """
 
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
 
     async def _send_changes_in_typing_to_remotes(
         self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
-    ):
+    ) -> None:
         """Process a change in typing of a room from replication, sending EDUs
         for any local users.
         """
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
             if self.is_mine_id(user_id):
                 await self._push_remote(RoomMember(room_id, user_id), False)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         return self._latest_room_serial
 
 
 class TypingWriterHandler(FollowerTypingHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
-        self._member_typing_until = {}  # clock time we expect to stop
+        # clock time we expect to stop
+        self._member_typing_until = {}  # type: Dict[RoomMember, int]
 
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
             "TypingStreamChangeCache", self._latest_room_serial
         )
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         super()._handle_timeout_for_member(now, member)
 
         if not self.is_typing(member):
@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             self._stopped_typing(member)
             return
 
-    async def started_typing(self, target_user, requester, room_id, timeout):
+    async def started_typing(
+        self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         if was_present:
             # No point sending another notification
-            return None
+            return
 
         self._push_update(member=member, typing=True)
 
-    async def stopped_typing(self, target_user, requester, room_id):
+    async def stopped_typing(
+        self, target_user: UserID, requester: Requester, room_id: str
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._stopped_typing(member)
 
-    def user_left_room(self, user, room_id):
+    def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
         if self.is_mine_id(user_id):
             member = RoomMember(room_id=room_id, user_id=user_id)
             self._stopped_typing(member)
 
-    def _stopped_typing(self, member):
+    def _stopped_typing(self, member: RoomMember) -> None:
         if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
-            return None
+            return
 
         self._member_typing_until.pop(member, None)
         self._member_last_federation_poke.pop(member, None)
 
         self._push_update(member=member, typing=False)
 
-    def _push_update(self, member, typing):
+    def _push_update(self, member: RoomMember, typing: bool) -> None:
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
             run_as_background_process(
@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._push_update_local(member=member, typing=typing)
 
-    async def _recv_edu(self, origin, content):
+    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
         room_id = content["room_id"]
         user_id = content["user_id"]
 
@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
             self._push_update_local(member=member, typing=content["typing"])
 
-    def _push_update_local(self, member, typing):
+    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
         room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
             room_set.add(member.user_id)
@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
-        )
+        )  # type: Optional[Iterable[str]]
 
         if changed_rooms is None:
             changed_rooms = self._room_serials
@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         # The writing process should never get updates from replication.
         raise Exception("Typing writer instance got typing info over replication")
 
 
 class TypingNotificationEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         # We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
         #
         self.get_typing_handler = hs.get_typing_handler
 
-    def _make_event_for(self, room_id):
+    def _make_event_for(self, room_id: str) -> JsonDict:
         typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": "m.typing",
@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    async def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(
+        self, from_key: int, room_ids: Iterable[str], **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    def get_current_key(self):
+    def get_current_key(self) -> int:
         return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 824f37f8f8..a68d5e790e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
 """
 
 from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS  # noqa: F401
+
+
+class UIAuthSessionDataConstants:
+    """Constants for use with AuthHandler.set_session_data"""
+
+    # used during registration and password reset to store a hashed copy of the
+    # password, so that the client does not need to submit it each time.
+    PASSWORD_HASH = "password_hash"
+
+    # used during registration to store the mxid of the registered user
+    REGISTERED_USER_ID = "registered_user_id"
+
+    # used by validate_user_via_ui_auth to store the mxid of the user we are validating
+    # for.
+    REQUEST_USER_ID = "request_user_id"
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index afbebfc200..8aedf5072e 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,14 +14,19 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import synapse.metrics
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
 from synapse.handlers.state_deltas import StateDeltasHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.roommember import ProfileInfo
+from synapse.types import JsonDict
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -36,7 +41,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     be in the directory or not when necessary.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.store = hs.get_datastore()
@@ -49,7 +54,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         self.search_all_users = hs.config.user_directory_search_all_users
         self.spam_checker = hs.get_spam_checker()
         # The current position in the current_state_delta stream
-        self.pos = None
+        self.pos = None  # type: Optional[int]
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -61,7 +66,9 @@ class UserDirectoryHandler(StateDeltasHandler):
             # we start populating the user directory
             self.clock.call_later(0, self.notify_new_event)
 
-    async def search_users(self, user_id, search_term, limit):
+    async def search_users(
+        self, user_id: str, search_term: str, limit: int
+    ) -> JsonDict:
         """Searches for users in directory
 
         Returns:
@@ -81,15 +88,15 @@ class UserDirectoryHandler(StateDeltasHandler):
         results = await self.store.search_user_dir(user_id, search_term, limit)
 
         # Remove any spammy users from the results.
-        results["results"] = [
-            user
-            for user in results["results"]
-            if not self.spam_checker.check_username_for_spam(user)
-        ]
+        non_spammy_users = []
+        for user in results["results"]:
+            if not await self.spam_checker.check_username_for_spam(user):
+                non_spammy_users.append(user)
+        results["results"] = non_spammy_users
 
         return results
 
-    def notify_new_event(self):
+    def notify_new_event(self) -> None:
         """Called when there may be more deltas to process
         """
         if not self.update_user_directory:
@@ -107,35 +114,37 @@ class UserDirectoryHandler(StateDeltasHandler):
         self._is_processing = True
         run_as_background_process("user_directory.notify_new_event", process)
 
-    async def handle_local_profile_change(self, user_id, profile):
+    async def handle_local_profile_change(
+        self, user_id: str, profile: ProfileInfo
+    ) -> None:
         """Called to update index of our local user profiles when they change
         irrespective of any rooms the user may be in.
         """
         # FIXME(#3714): We should probably do this in the same worker as all
         # the other changes.
-        is_support = await self.store.is_support_user(user_id)
+
         # Support users are for diagnostics and should not appear in the user directory.
-        if not is_support:
+        is_support = await self.store.is_support_user(user_id)
+        # When change profile information of deactivated user it should not appear in the user directory.
+        is_deactivated = await self.store.get_user_deactivated_status(user_id)
+
+        if not (is_support or is_deactivated):
             await self.store.update_profile_in_user_dir(
                 user_id, profile.display_name, profile.avatar_url
             )
 
-    async def handle_user_deactivated(self, user_id):
+    async def handle_user_deactivated(self, user_id: str) -> None:
         """Called when a user ID is deactivated
         """
         # FIXME(#3714): We should probably do this in the same worker as all
         # the other changes.
         await self.store.remove_from_user_dir(user_id)
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()
 
-        # If still None then the initial background update hasn't happened yet
-        if self.pos is None:
-            return None
-
         # Loop round handling deltas until we're up to date
         while True:
             with Measure(self.clock, "user_dir_delta"):
@@ -162,7 +171,7 @@ class UserDirectoryHandler(StateDeltasHandler):
 
                 await self.store.update_user_directory_stream_pos(max_pos)
 
-    async def _handle_deltas(self, deltas):
+    async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
         """Called with the state deltas to process
         """
         for delta in deltas:
@@ -220,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
 
                     if change:  # The user joined
                         event = await self.store.get_event(event_id, allow_none=True)
+                        # It isn't expected for this event to not exist, but we
+                        # don't want the entire background process to break.
+                        if event is None:
+                            continue
+
                         profile = ProfileInfo(
                             avatar_url=event.content.get("avatar_url"),
                             display_name=event.content.get("displayname"),
@@ -232,16 +246,20 @@ class UserDirectoryHandler(StateDeltasHandler):
                 logger.debug("Ignoring irrelevant type: %r", typ)
 
     async def _handle_room_publicity_change(
-        self, room_id, prev_event_id, event_id, typ
-    ):
+        self,
+        room_id: str,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+        typ: str,
+    ) -> None:
         """Handle a room having potentially changed from/to world_readable/publicly
         joinable.
 
         Args:
-            room_id (str)
-            prev_event_id (str|None): The previous event before the state change
-            event_id (str|None): The new event after the state change
-            typ (str): Type of the event
+            room_id: The ID of the room which changed.
+            prev_event_id: The previous event before the state change
+            event_id: The new event after the state change
+            typ: Type of the event
         """
         logger.debug("Handling change for %s: %s", typ, room_id)
 
@@ -250,7 +268,7 @@ class UserDirectoryHandler(StateDeltasHandler):
                 prev_event_id,
                 event_id,
                 key_name="history_visibility",
-                public_value="world_readable",
+                public_value=HistoryVisibility.WORLD_READABLE,
             )
         elif typ == EventTypes.JoinRules:
             change = await self._get_key_change(
@@ -299,12 +317,14 @@ class UserDirectoryHandler(StateDeltasHandler):
         for user_id, profile in users_with_profile.items():
             await self._handle_new_user(room_id, user_id, profile)
 
-    async def _handle_new_user(self, room_id, user_id, profile):
+    async def _handle_new_user(
+        self, room_id: str, user_id: str, profile: ProfileInfo
+    ) -> None:
         """Called when we might need to add user to directory
 
         Args:
-            room_id (str): room_id that user joined or started being public
-            user_id (str)
+            room_id: The room ID that user joined or started being public
+            user_id
         """
         logger.debug("Adding new user to dir, %r", user_id)
 
@@ -352,12 +372,12 @@ class UserDirectoryHandler(StateDeltasHandler):
             if to_insert:
                 await self.store.add_users_who_share_private_room(room_id, to_insert)
 
-    async def _handle_remove_user(self, room_id, user_id):
+    async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
         """Called when we might need to remove user from directory
 
         Args:
-            room_id (str): room_id that user left or stopped being public that
-            user_id (str)
+            room_id: The room ID that user left or stopped being public that
+            user_id
         """
         logger.debug("Removing user %r", user_id)
 
@@ -370,7 +390,13 @@ class UserDirectoryHandler(StateDeltasHandler):
         if len(rooms_user_is_in) == 0:
             await self.store.remove_from_user_dir(user_id)
 
-    async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+    async def _handle_profile_change(
+        self,
+        user_id: str,
+        room_id: str,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+    ) -> None:
         """Check member event changes for any profile changes and update the
         database if there are.
         """
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 59b01b812c..4bc3cb53f0 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -17,6 +17,7 @@ import re
 
 from twisted.internet import task
 from twisted.web.client import FileBodyProducer
+from twisted.web.iweb import IRequest
 
 from synapse.api.errors import SynapseError
 
@@ -50,3 +51,17 @@ class QuieterFileBodyProducer(FileBodyProducer):
             FileBodyProducer.stopProducing(self)
         except task.TaskStopped:
             pass
+
+
+def get_request_user_agent(request: IRequest, default: str = "") -> str:
+    """Return the last User-Agent header, or the given default.
+    """
+    # There could be raw utf-8 bytes in the User-Agent header.
+
+    # N.B. if you don't do this, the logger explodes cryptically
+    # with maximum recursion trying to log errors about
+    # the charset problem.
+    # c.f. https://github.com/matrix-org/synapse/issues/3471
+
+    h = request.getHeader(b"User-Agent")
+    return h.decode("ascii", "replace") if h else default
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e5b13593f2..37ccf5ab98 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -32,7 +32,7 @@ from typing import (
 
 import treq
 from canonicaljson import encode_canonical_json
-from netaddr import IPAddress, IPSet
+from netaddr import AddrFormatError, IPAddress, IPSet
 from prometheus_client import Counter
 from zope.interface import implementer, provider
 
@@ -125,7 +125,7 @@ def _make_scheduler(reactor):
     return _scheduler
 
 
-class IPBlacklistingResolver:
+class _IPBlacklistingResolver:
     """
     A proxy for reactor.nameResolver which only produces non-blacklisted IP
     addresses, preventing DNS rebinding attacks on URL preview.
@@ -199,6 +199,35 @@ class IPBlacklistingResolver:
         return r
 
 
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+    """
+    A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+    addresses, to prevent DNS rebinding.
+    """
+
+    def __init__(
+        self,
+        reactor: IReactorPluggableNameResolver,
+        ip_whitelist: Optional[IPSet],
+        ip_blacklist: IPSet,
+    ):
+        self._reactor = reactor
+
+        # We need to use a DNS resolver which filters out blacklisted IP
+        # addresses, to prevent DNS rebinding.
+        self._nameResolver = _IPBlacklistingResolver(
+            self._reactor, ip_whitelist, ip_blacklist
+        )
+
+    def __getattr__(self, attr: str) -> Any:
+        # Passthrough to the real reactor except for the DNS resolver.
+        if attr == "nameResolver":
+            return self._nameResolver
+        else:
+            return getattr(self._reactor, attr)
+
+
 class BlacklistingAgentWrapper(Agent):
     """
     An Agent wrapper which will prevent access to IP addresses being accessed
@@ -232,16 +261,16 @@ class BlacklistingAgentWrapper(Agent):
 
         try:
             ip_address = IPAddress(h.hostname)
-
+        except AddrFormatError:
+            # Not an IP
+            pass
+        else:
             if check_against_blacklist(
                 ip_address, self._ip_whitelist, self._ip_blacklist
             ):
                 logger.info("Blocking access to %s due to blacklist" % (ip_address,))
                 e = SynapseError(403, "IP address blocked by IP blacklist entry")
                 return defer.fail(Failure(e))
-        except Exception:
-            # Not an IP
-            pass
 
         return self._agent.request(
             method, uri, headers=headers, bodyProducer=bodyProducer
@@ -292,22 +321,11 @@ class SimpleHttpClient:
         self.user_agent = self.user_agent.encode("ascii")
 
         if self._ip_blacklist:
-            real_reactor = hs.get_reactor()
             # If we have an IP blacklist, we need to use a DNS resolver which
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
-            nameResolver = IPBlacklistingResolver(
-                real_reactor, self._ip_whitelist, self._ip_blacklist
+            self.reactor = BlacklistingReactorWrapper(
+                hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
             )
-
-            @implementer(IReactorPluggableNameResolver)
-            class Reactor:
-                def __getattr__(_self, attr):
-                    if attr == "nameResolver":
-                        return nameResolver
-                    else:
-                        return getattr(real_reactor, attr)
-
-            self.reactor = Reactor()
         else:
             self.reactor = hs.get_reactor()
 
@@ -323,6 +341,7 @@ class SimpleHttpClient:
 
         self.agent = ProxyAgent(
             self.reactor,
+            hs.get_reactor(),
             connectTimeout=15,
             contextFactory=self.hs.get_http_client_context_factory(),
             pool=pool,
@@ -702,11 +721,14 @@ class SimpleHttpClient:
 
         try:
             length = await make_deferred_yieldable(
-                readBodyToFile(response, output_stream, max_size)
+                read_body_with_max_size(response, output_stream, max_size)
+            )
+        except BodyExceededMaxSize:
+            raise SynapseError(
+                502,
+                "Requested file is too large > %r bytes" % (max_size,),
+                Codes.TOO_LARGE,
             )
-        except SynapseError:
-            # This can happen e.g. because the body is too large.
-            raise
         except Exception as e:
             raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
 
@@ -730,7 +752,11 @@ def _timeout_to_request_timed_out_error(f: Failure):
     return f
 
 
-class _ReadBodyToFileProtocol(protocol.Protocol):
+class BodyExceededMaxSize(Exception):
+    """The maximum allowed size of the HTTP body was exceeded."""
+
+
+class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
     def __init__(
         self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
     ):
@@ -740,20 +766,24 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
         self.max_size = max_size
 
     def dataReceived(self, data: bytes) -> None:
+        # If the deferred was called, bail early.
+        if self.deferred.called:
+            return
+
         self.stream.write(data)
         self.length += len(data)
+        # The first time the maximum size is exceeded, error and cancel the
+        # connection. dataReceived might be called again if data was received
+        # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
-            self.deferred.errback(
-                SynapseError(
-                    502,
-                    "Requested file is too large > %r bytes" % (self.max_size,),
-                    Codes.TOO_LARGE,
-                )
-            )
-            self.deferred = defer.Deferred()
+            self.deferred.errback(BodyExceededMaxSize())
             self.transport.loseConnection()
 
     def connectionLost(self, reason: Failure) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
         if reason.check(ResponseDone):
             self.deferred.callback(self.length)
         elif reason.check(PotentialDataLoss):
@@ -764,12 +794,15 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
             self.deferred.errback(reason)
 
 
-def readBodyToFile(
+def read_body_with_max_size(
     response: IResponse, stream: BinaryIO, max_size: Optional[int]
 ) -> defer.Deferred:
     """
     Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
 
+    If the maximum file size is reached, the returned Deferred will resolve to a
+    Failure with a BodyExceededMaxSize exception.
+
     Args:
         response: The HTTP response to read from.
         stream: The file-object to write to.
@@ -780,7 +813,7 @@ def readBodyToFile(
     """
 
     d = defer.Deferred()
-    response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
+    response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
     return d
 
 
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
deleted file mode 100644
index 92a5b606c8..0000000000
--- a/synapse/http/endpoint.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import re
-
-logger = logging.getLogger(__name__)
-
-
-def parse_server_name(server_name):
-    """Split a server name into host/port parts.
-
-    Args:
-        server_name (str): server name to parse
-
-    Returns:
-        Tuple[str, int|None]: host/port parts.
-
-    Raises:
-        ValueError if the server name could not be parsed.
-    """
-    try:
-        if server_name[-1] == "]":
-            # ipv6 literal, hopefully
-            return server_name, None
-
-        domain_port = server_name.rsplit(":", 1)
-        domain = domain_port[0]
-        port = int(domain_port[1]) if domain_port[1:] else None
-        return domain, port
-    except Exception:
-        raise ValueError("Invalid server name '%s'" % server_name)
-
-
-VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
-
-
-def parse_and_validate_server_name(server_name):
-    """Split a server name into host/port parts and do some basic validation.
-
-    Args:
-        server_name (str): server name to parse
-
-    Returns:
-        Tuple[str, int|None]: host/port parts.
-
-    Raises:
-        ValueError if the server name could not be parsed.
-    """
-    host, port = parse_server_name(server_name)
-
-    # these tests don't need to be bulletproof as we'll find out soon enough
-    # if somebody is giving us invalid data. What we *do* need is to be sure
-    # that nobody is sneaking IP literals in that look like hostnames, etc.
-
-    # look for ipv6 literals
-    if host[0] == "[":
-        if host[-1] != "]":
-            raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
-        return host, port
-
-    # otherwise it should only be alphanumerics.
-    if not VALID_HOST_REGEX.match(host):
-        raise ValueError(
-            "Server name '%s' contains invalid characters" % (server_name,)
-        )
-
-    return host, port
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index e77f9587d0..4c06a117d3 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -16,7 +16,7 @@ import logging
 import urllib.parse
 from typing import List, Optional
 
-from netaddr import AddrFormatError, IPAddress
+from netaddr import AddrFormatError, IPAddress, IPSet
 from zope.interface import implementer
 
 from twisted.internet import defer
@@ -31,6 +31,7 @@ from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
 
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
+from synapse.http.client import BlacklistingAgentWrapper
 from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.http.federation.well_known_resolver import WellKnownResolver
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -70,6 +71,7 @@ class MatrixFederationAgent:
         reactor: IReactorCore,
         tls_client_options_factory: Optional[FederationPolicyForHTTPS],
         user_agent: bytes,
+        ip_blacklist: IPSet,
         _srv_resolver: Optional[SrvResolver] = None,
         _well_known_resolver: Optional[WellKnownResolver] = None,
     ):
@@ -90,12 +92,17 @@ class MatrixFederationAgent:
         self.user_agent = user_agent
 
         if _well_known_resolver is None:
+            # Note that the name resolver has already been wrapped in a
+            # IPBlacklistingResolver by MatrixFederationHttpClient.
             _well_known_resolver = WellKnownResolver(
                 self._reactor,
-                agent=Agent(
-                    self._reactor,
-                    pool=self._pool,
-                    contextFactory=tls_client_options_factory,
+                agent=BlacklistingAgentWrapper(
+                    Agent(
+                        self._reactor,
+                        pool=self._pool,
+                        contextFactory=tls_client_options_factory,
+                    ),
+                    ip_blacklist=ip_blacklist,
                 ),
                 user_agent=self.user_agent,
             )
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 5e08ef1664..b3b6dbcab0 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -15,17 +15,19 @@
 import logging
 import random
 import time
+from io import BytesIO
 from typing import Callable, Dict, Optional, Tuple
 
 import attr
 
 from twisted.internet import defer
 from twisted.internet.interfaces import IReactorTime
-from twisted.web.client import RedirectAgent, readBody
+from twisted.web.client import RedirectAgent
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent, IResponse
 
+from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import Clock, json_decoder
 from synapse.util.caches.ttlcache import TTLCache
@@ -53,6 +55,9 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
 # lower bound for .well-known cache period
 WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
 
+# The maximum size (in bytes) to allow a well-known file to be.
+WELL_KNOWN_MAX_SIZE = 50 * 1024  # 50 KiB
+
 # Attempt to refetch a cached well-known N% of the TTL before it expires.
 # e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then
 # we'll start trying to refetch 1 minute before it expires.
@@ -229,6 +234,9 @@ class WellKnownResolver:
             server_name: name of the server, from the requested url
             retry: Whether to retry the request if it fails.
 
+        Raises:
+            _FetchWellKnownFailure if we fail to lookup a result
+
         Returns:
             Returns the response object and body. Response may be a non-200 response.
         """
@@ -250,7 +258,11 @@ class WellKnownResolver:
                         b"GET", uri, headers=Headers(headers)
                     )
                 )
-                body = await make_deferred_yieldable(readBody(response))
+                body_stream = BytesIO()
+                await make_deferred_yieldable(
+                    read_body_with_max_size(response, body_stream, WELL_KNOWN_MAX_SIZE)
+                )
+                body = body_stream.getvalue()
 
                 if 500 <= response.code < 600:
                     raise Exception("Non-200 response %s" % (response.code,))
@@ -259,6 +271,15 @@ class WellKnownResolver:
             except defer.CancelledError:
                 # Bail if we've been cancelled
                 raise
+            except BodyExceededMaxSize:
+                # If the well-known file was too large, do not keep attempting
+                # to download it, but consider it a temporary error.
+                logger.warning(
+                    "Requested .well-known file for %s is too large > %r bytes",
+                    server_name.decode("ascii"),
+                    WELL_KNOWN_MAX_SIZE,
+                )
+                raise _FetchWellKnownFailure(temporary=True)
             except Exception as e:
                 if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS:
                     logger.info("Error fetching %s: %s", uri_str, e)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 4e27f93b7a..19293bf673 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -26,11 +26,10 @@ import treq
 from canonicaljson import encode_canonical_json
 from prometheus_client import Counter
 from signedjson.sign import sign_json
-from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
+from twisted.internet.interfaces import IReactorTime
 from twisted.internet.task import _EPSILON, Cooperator
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse
@@ -38,16 +37,19 @@ from twisted.web.iweb import IBodyProducer, IResponse
 import synapse.metrics
 import synapse.util.retryutils
 from synapse.api.errors import (
+    Codes,
     FederationDeniedError,
     HttpResponseException,
     RequestSendFailed,
+    SynapseError,
 )
 from synapse.http import QuieterFileBodyProducer
 from synapse.http.client import (
     BlacklistingAgentWrapper,
-    IPBlacklistingResolver,
+    BlacklistingReactorWrapper,
+    BodyExceededMaxSize,
     encode_query_args,
-    readBodyToFile,
+    read_body_with_max_size,
 )
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
 from synapse.logging.context import make_deferred_yieldable
@@ -172,6 +174,16 @@ async def _handle_json_response(
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
         body = await make_deferred_yieldable(d)
+    except ValueError as e:
+        # The JSON content was invalid.
+        logger.warning(
+            "{%s} [%s] Failed to parse JSON response - %s %s",
+            request.txn_id,
+            request.destination,
+            request.method,
+            request.uri.decode("ascii"),
+        )
+        raise RequestSendFailed(e, can_retry=False) from e
     except defer.TimeoutError as e:
         logger.warning(
             "{%s} [%s] Timed out reading response - %s %s",
@@ -221,31 +233,22 @@ class MatrixFederationHttpClient:
         self.signing_key = hs.signing_key
         self.server_name = hs.hostname
 
-        real_reactor = hs.get_reactor()
-
         # We need to use a DNS resolver which filters out blacklisted IP
         # addresses, to prevent DNS rebinding.
-        nameResolver = IPBlacklistingResolver(
-            real_reactor, None, hs.config.federation_ip_range_blacklist
+        self.reactor = BlacklistingReactorWrapper(
+            hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
         )
 
-        @implementer(IReactorPluggableNameResolver)
-        class Reactor:
-            def __getattr__(_self, attr):
-                if attr == "nameResolver":
-                    return nameResolver
-                else:
-                    return getattr(real_reactor, attr)
-
-        self.reactor = Reactor()
-
         user_agent = hs.version_string
         if hs.config.user_agent_suffix:
             user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
         user_agent = user_agent.encode("ascii")
 
         self.agent = MatrixFederationAgent(
-            self.reactor, tls_client_options_factory, user_agent
+            self.reactor,
+            tls_client_options_factory,
+            user_agent,
+            hs.config.federation_ip_range_blacklist,
         )
 
         # Use a BlacklistingAgentWrapper to prevent circumventing the IP
@@ -985,9 +988,15 @@ class MatrixFederationHttpClient:
         headers = dict(response.headers.getAllRawHeaders())
 
         try:
-            d = readBodyToFile(response, output_stream, max_size)
+            d = read_body_with_max_size(response, output_stream, max_size)
             d.addTimeout(self.default_timeout, self.reactor)
             length = await make_deferred_yieldable(d)
+        except BodyExceededMaxSize:
+            msg = "Requested file is too large > %r bytes" % (max_size,)
+            logger.warning(
+                "{%s} [%s] %s", request.txn_id, request.destination, msg,
+            )
+            raise SynapseError(502, msg, Codes.TOO_LARGE)
         except Exception as e:
             logger.warning(
                 "{%s} [%s] Error reading response: %s",
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index e32d3f43e0..b730d2c634 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -39,6 +39,10 @@ class ProxyAgent(_AgentBase):
         reactor: twisted reactor to place outgoing
             connections.
 
+        proxy_reactor: twisted reactor to use for connections to the proxy server
+                       reactor might have some blacklisting applied (i.e. for DNS queries),
+                       but we need unblocked access to the proxy.
+
         contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
             verification parameters of OpenSSL.  The default is to use a
             `BrowserLikePolicyForHTTPS`, so unless you have special
@@ -59,6 +63,7 @@ class ProxyAgent(_AgentBase):
     def __init__(
         self,
         reactor,
+        proxy_reactor=None,
         contextFactory=BrowserLikePolicyForHTTPS(),
         connectTimeout=None,
         bindAddress=None,
@@ -68,6 +73,11 @@ class ProxyAgent(_AgentBase):
     ):
         _AgentBase.__init__(self, reactor, pool)
 
+        if proxy_reactor is None:
+            self.proxy_reactor = reactor
+        else:
+            self.proxy_reactor = proxy_reactor
+
         self._endpoint_kwargs = {}
         if connectTimeout is not None:
             self._endpoint_kwargs["timeout"] = connectTimeout
@@ -75,11 +85,11 @@ class ProxyAgent(_AgentBase):
             self._endpoint_kwargs["bindAddress"] = bindAddress
 
         self.http_proxy_endpoint = _http_proxy_endpoint(
-            http_proxy, reactor, **self._endpoint_kwargs
+            http_proxy, self.proxy_reactor, **self._endpoint_kwargs
         )
 
         self.https_proxy_endpoint = _http_proxy_endpoint(
-            https_proxy, reactor, **self._endpoint_kwargs
+            https_proxy, self.proxy_reactor, **self._endpoint_kwargs
         )
 
         self._policy_for_https = contextFactory
@@ -137,7 +147,7 @@ class ProxyAgent(_AgentBase):
             request_path = uri
         elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
             endpoint = HTTPConnectProxyEndpoint(
-                self._reactor,
+                self.proxy_reactor,
                 self.https_proxy_endpoint,
                 parsed_uri.host,
                 parsed_uri.port,
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 6a4e429a6c..8249732b27 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,10 +22,22 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Pattern,
+    Tuple,
+    Union,
+)
 
 import jinja2
 from canonicaljson import iterencode_canonical_json
+from typing_extensions import Protocol
 from zope.interface import implementer
 
 from twisted.internet import defer, interfaces
@@ -168,11 +180,25 @@ def wrap_async_request_handler(h):
     return preserve_fn(wrapped_async_request_handler)
 
 
-class HttpServer:
+# Type of a callback method for processing requests
+# it is actually called with a SynapseRequest and a kwargs dict for the params,
+# but I can't figure out how to represent that.
+ServletCallback = Callable[
+    ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
+]
+
+
+class HttpServer(Protocol):
     """ Interface for registering callbacks on a HTTP server
     """
 
-    def register_paths(self, method, path_patterns, callback):
+    def register_paths(
+        self,
+        method: str,
+        path_patterns: Iterable[Pattern],
+        callback: ServletCallback,
+        servlet_classname: str,
+    ) -> None:
         """ Register a callback that gets fired if we receive a http request
         with the given method for a path that matches the given regex.
 
@@ -180,12 +206,14 @@ class HttpServer:
         an unpacked tuple.
 
         Args:
-            method (str): The method to listen to.
-            path_patterns (list<SRE_Pattern>): The regex used to match requests.
-            callback (function): The function to fire if we receive a matched
+            method: The HTTP method to listen to.
+            path_patterns: The regex used to match requests.
+            callback: The function to fire if we receive a matched
                 request. The first argument will be the request object and
                 subsequent arguments will be any matched groups from the regex.
-                This should return a tuple of (code, response).
+                This should return either tuple of (code, response), or None.
+            servlet_classname (str): The name of the handler to be used in prometheus
+                and opentracing logs.
         """
         pass
 
@@ -275,6 +303,10 @@ class DirectServeJsonResource(_AsyncResource):
     formatting responses and errors as JSON.
     """
 
+    def __init__(self, canonical_json=False, extract_context=False):
+        super().__init__(extract_context)
+        self.canonical_json = canonical_json
+
     def _send_response(
         self, request: Request, code: int, response_object: Any,
     ):
@@ -318,9 +350,7 @@ class JsonResource(DirectServeJsonResource):
     )
 
     def __init__(self, hs, canonical_json=True, extract_context=False):
-        super().__init__(extract_context)
-
-        self.canonical_json = canonical_json
+        super().__init__(canonical_json, extract_context)
         self.clock = hs.get_clock()
         self.path_regexs = {}
         self.hs = hs
@@ -352,7 +382,7 @@ class JsonResource(DirectServeJsonResource):
 
     def _get_handler_for_request(
         self, request: SynapseRequest
-    ) -> Tuple[Callable, str, Dict[str, str]]:
+    ) -> Tuple[ServletCallback, str, Dict[str, str]]:
         """Finds a callback method to handle the given request.
 
         Returns:
@@ -731,6 +761,13 @@ def set_clickjacking_protection_headers(request: Request):
     request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
 
 
+def respond_with_redirect(request: Request, url: bytes) -> None:
+    """Write a 302 response to the request, if it is still alive."""
+    logger.debug("Redirect to %s", url.decode("utf-8"))
+    request.redirect(url)
+    finish_request(request)
+
+
 def finish_request(request: Request):
     """ Finish writing the response to the request.
 
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5f0581dc3f..12ec3f851f 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -20,7 +20,7 @@ from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
 
 from synapse.config.server import ListenerConfig
-from synapse.http import redact_uri
+from synapse.http import get_request_user_agent, redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.types import Requester
@@ -113,23 +113,13 @@ class SynapseRequest(Request):
             method = self.method.decode("ascii")
         return method
 
-    def get_user_agent(self, default: str) -> str:
-        """Return the last User-Agent header, or the given default.
-        """
-        user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
-        if user_agent is None:
-            return default
-
-        return user_agent.decode("ascii", "replace")
-
     def render(self, resrc):
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.
 
         # create a LogContext for this request
         request_id = self.get_request_id()
-        logcontext = self.logcontext = LoggingContext(request_id)
-        logcontext.request = request_id
+        self.logcontext = LoggingContext(request_id, request=request_id)
 
         # override the Server header which is set by twisted
         self.setHeader("Server", self.site.server_version_string)
@@ -293,12 +283,7 @@ class SynapseRequest(Request):
             # and can see that we're doing something wrong.
             authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
 
-        # ...or could be raw utf-8 bytes in the User-Agent header.
-        # N.B. if you don't do this, the logger explodes cryptically
-        # with maximum recursion trying to log errors about
-        # the charset problem.
-        # c.f. https://github.com/matrix-org/synapse/issues/3471
-        user_agent = self.get_user_agent("-")
+        user_agent = get_request_user_agent(self, "-")
 
         code = str(self.code)
         if not self.finished:
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774cc5..c2db8b45f3 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel:
     def copy_to(self, record):
         pass
 
-    def copy_to_twisted_log_entry(self, record):
-        record["request"] = None
-        record["scope"] = None
-
     def start(self, rusage: "Optional[resource._RUsage]"):
         pass
 
@@ -256,7 +252,12 @@ class LoggingContext:
         "scope",
     ]
 
-    def __init__(self, name=None, parent_context=None, request=None) -> None:
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        parent_context: "Optional[LoggingContext]" = None,
+        request: Optional[str] = None,
+    ) -> None:
         self.previous_context = current_context()
         self.name = name
 
@@ -372,13 +373,6 @@ class LoggingContext:
         # we also track the current scope:
         record.scope = self.scope
 
-    def copy_to_twisted_log_entry(self, record) -> None:
-        """
-        Copy logging fields from this context to a Twisted log record.
-        """
-        record["request"] = self.request
-        record["scope"] = self.scope
-
     def start(self, rusage: "Optional[resource._RUsage]") -> None:
         """
         Record that this logcontext is currently running.
@@ -542,28 +536,25 @@ class LoggingContext:
 class LoggingContextFilter(logging.Filter):
     """Logging filter that adds values from the current logging context to each
     record.
-    Args:
-        **defaults: Default values to avoid formatters complaining about
-            missing fields
     """
 
-    def __init__(self, **defaults) -> None:
-        self.defaults = defaults
+    def __init__(self, request: str = ""):
+        self._default_request = request
 
-    def filter(self, record) -> Literal[True]:
+    def filter(self, record: logging.LogRecord) -> Literal[True]:
         """Add each fields from the logging contexts to the record.
         Returns:
             True to include the record in the log output.
         """
         context = current_context()
-        for key, value in self.defaults.items():
-            setattr(record, key, value)
+        record.request = self._default_request  # type: ignore
 
         # context should never be None, but if it somehow ends up being, then
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
-            context.copy_to(record)
+            # Logging is interested in the request.
+            record.request = context.request  # type: ignore
 
         return True
 
@@ -630,9 +621,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
     return current
 
 
-def nested_logging_context(
-    suffix: str, parent_context: Optional[LoggingContext] = None
-) -> LoggingContext:
+def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.
 
     The nested logging context will have a 'request' made up of the parent context's
@@ -646,20 +635,23 @@ def nested_logging_context(
             # ... do stuff
 
     Args:
-        suffix (str): suffix to add to the parent context's 'request'.
-        parent_context (LoggingContext|None): parent context. Will use the current context
-            if None.
+        suffix: suffix to add to the parent context's 'request'.
 
     Returns:
         LoggingContext: new logging context.
     """
-    if parent_context is not None:
-        context = parent_context  # type: LoggingContextOrSentinel
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Starting nested logging context from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+        prefix = ""
     else:
-        context = current_context()
-    return LoggingContext(
-        parent_context=context, request=str(context.request) + "-" + suffix
-    )
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
+        prefix = str(parent_context.request)
+    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
 
 
 def preserve_fn(f):
@@ -836,10 +828,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = current_context()
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Calling defer_to_threadpool from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+    else:
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
 
     def g():
-        with LoggingContext(parent_context=logcontext):
+        with LoggingContext(parent_context=parent_context):
             return f(*args, **kwargs)
 
     return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index ab586c318c..0538350f38 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -791,7 +791,7 @@ def tag_args(func):
 
     @wraps(func)
     def _tag_args_inner(*args, **kwargs):
-        argspec = inspect.getargspec(func)
+        argspec = inspect.getfullargspec(func)
         for i, arg in enumerate(argspec.args[1:]):
             set_tag("ARG_" + arg, args[i])
         set_tag("args", args[len(argspec.args) :])
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 658f6ecd72..70e0fa45d9 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import logging
 import threading
 from functools import wraps
@@ -25,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     import resource
@@ -199,19 +199,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
         _background_process_start_count.labels(desc).inc()
         _background_process_in_flight_count.labels(desc).inc()
 
-        with BackgroundProcessLoggingContext(desc) as context:
-            context.request = "%s-%i" % (desc, count)
+        with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
                     ctx = start_active_span(desc, tags={"request_id": context.request})
                 with ctx:
-                    result = func(*args, **kwargs)
-
-                    if inspect.isawaitable(result):
-                        result = await result
-
-                    return result
+                    return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
                 logger.exception(
                     "Background process '%s' threw an exception", desc,
@@ -249,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
 
     __slots__ = ["_proc"]
 
-    def __init__(self, name: str):
-        super().__init__(name)
+    def __init__(self, name: str, request: Optional[str] = None):
+        super().__init__(name, request=request)
 
         self._proc = _BackgroundProcess(name, self)
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 72ab5750cc..401d577293 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -279,7 +279,11 @@ class ModuleApi:
         )
 
     async def complete_sso_login_async(
-        self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
+        self,
+        registered_user_id: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
+        new_user: bool = False,
     ):
         """Complete a SSO login by redirecting the user to a page to confirm whether they
         want their access token sent to `client_redirect_url`, or redirect them to that
@@ -291,9 +295,11 @@ class ModuleApi:
             request: The request to respond to.
             client_redirect_url: The URL to which to offer to redirect the user (or to
                 redirect them directly if whitelisted).
+            new_user: set to true to use wording for the consent appropriate to a user
+                who has just registered.
         """
         await self._auth_handler.complete_sso_login(
-            registered_user_id, request, client_redirect_url,
+            registered_user_id, request, client_redirect_url, new_user=new_user
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/notifier.py b/synapse/notifier.py
index a17352ef46..0745899b48 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -34,7 +34,7 @@ from prometheus_client import Counter
 from twisted.internet import defer
 
 import synapse.server
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.api.errors import AuthError
 from synapse.events import EventBase
 from synapse.handlers.presence import format_user_presence_state
@@ -396,31 +396,30 @@ class Notifier:
 
         Will wake up all listeners for the given users and rooms.
         """
-        with PreserveLoggingContext():
-            with Measure(self.clock, "on_new_event"):
-                user_streams = set()
+        with Measure(self.clock, "on_new_event"):
+            user_streams = set()
 
-                for user in users:
-                    user_stream = self.user_to_user_stream.get(str(user))
-                    if user_stream is not None:
-                        user_streams.add(user_stream)
+            for user in users:
+                user_stream = self.user_to_user_stream.get(str(user))
+                if user_stream is not None:
+                    user_streams.add(user_stream)
 
-                for room in rooms:
-                    user_streams |= self.room_to_user_streams.get(room, set())
+            for room in rooms:
+                user_streams |= self.room_to_user_streams.get(room, set())
 
-                time_now_ms = self.clock.time_msec()
-                for user_stream in user_streams:
-                    try:
-                        user_stream.notify(stream_key, new_token, time_now_ms)
-                    except Exception:
-                        logger.exception("Failed to notify listener")
+            time_now_ms = self.clock.time_msec()
+            for user_stream in user_streams:
+                try:
+                    user_stream.notify(stream_key, new_token, time_now_ms)
+                except Exception:
+                    logger.exception("Failed to notify listener")
 
-                self.notify_replication()
+            self.notify_replication()
 
-                # Notify appservices
-                self._notify_app_services_ephemeral(
-                    stream_key, new_token, users,
-                )
+            # Notify appservices
+            self._notify_app_services_ephemeral(
+                stream_key, new_token, users,
+            )
 
     def on_new_replication_data(self) -> None:
         """Used to inform replication listeners that something has happened
@@ -611,7 +610,9 @@ class Notifier:
             room_id, EventTypes.RoomHistoryVisibility, ""
         )
         if state and "history_visibility" in state.content:
-            return state.content["history_visibility"] == "world_readable"
+            return (
+                state.content["history_visibility"] == HistoryVisibility.WORLD_READABLE
+            )
         else:
             return False
 
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 5a437f9810..f4f7ec96f8 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -13,7 +13,111 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+import attr
+
+from synapse.types import JsonDict, RoomStreamToken
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
+@attr.s(slots=True)
+class PusherConfig:
+    """Parameters necessary to configure a pusher."""
+
+    id = attr.ib(type=Optional[str])
+    user_name = attr.ib(type=str)
+    access_token = attr.ib(type=Optional[int])
+    profile_tag = attr.ib(type=str)
+    kind = attr.ib(type=str)
+    app_id = attr.ib(type=str)
+    app_display_name = attr.ib(type=str)
+    device_display_name = attr.ib(type=str)
+    pushkey = attr.ib(type=str)
+    ts = attr.ib(type=int)
+    lang = attr.ib(type=Optional[str])
+    data = attr.ib(type=Optional[JsonDict])
+    last_stream_ordering = attr.ib(type=int)
+    last_success = attr.ib(type=Optional[int])
+    failing_since = attr.ib(type=Optional[int])
+
+    def as_dict(self) -> Dict[str, Any]:
+        """Information that can be retrieved about a pusher after creation."""
+        return {
+            "app_display_name": self.app_display_name,
+            "app_id": self.app_id,
+            "data": self.data,
+            "device_display_name": self.device_display_name,
+            "kind": self.kind,
+            "lang": self.lang,
+            "profile_tag": self.profile_tag,
+            "pushkey": self.pushkey,
+        }
+
+
+@attr.s(slots=True)
+class ThrottleParams:
+    """Parameters for controlling the rate of sending pushes via email."""
+
+    last_sent_ts = attr.ib(type=int)
+    throttle_ms = attr.ib(type=int)
+
+
+class Pusher(metaclass=abc.ABCMeta):
+    def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
+        self.hs = hs
+        self.store = self.hs.get_datastore()
+        self.clock = self.hs.get_clock()
+
+        self.pusher_id = pusher_config.id
+        self.user_id = pusher_config.user_name
+        self.app_id = pusher_config.app_id
+        self.pushkey = pusher_config.pushkey
+
+        self.last_stream_ordering = pusher_config.last_stream_ordering
+
+        # This is the highest stream ordering we know it's safe to process.
+        # When new events arrive, we'll be given a window of new events: we
+        # should honour this rather than just looking for anything higher
+        # because of potential out-of-order event serialisation.
+        self.max_stream_ordering = self.store.get_room_max_stream_ordering()
+
+    def on_new_notifications(self, max_token: RoomStreamToken) -> None:
+        # We just use the minimum stream ordering and ignore the vector clock
+        # component. This is safe to do as long as we *always* ignore the vector
+        # clock components.
+        max_stream_ordering = max_token.stream
+
+        self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+        self._start_processing()
+
+    @abc.abstractmethod
+    def _start_processing(self):
+        """Start processing push notifications."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def on_started(self, have_notifs: bool) -> None:
+        """Called when this pusher has been started.
+
+        Args:
+            should_check_for_notifs: Whether we should immediately
+                check for push to send. Set to False only if it's known there
+                is nothing to send
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def on_stop(self) -> None:
+        raise NotImplementedError()
+
 
 class PusherConfigException(Exception):
-    def __init__(self, msg):
-        super().__init__(msg)
+    """An error occurred when creating a pusher."""
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index fabc9ba126..aaed28650d 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -14,19 +14,22 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
 from synapse.util.metrics import Measure
 
-from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
 class ActionGenerator:
-    def __init__(self, hs):
-        self.hs = hs
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
         self.bulk_evaluator = BulkPushRuleEvaluator(hs)
         # really we want to get all user ids and all profile tags too,
         # since we want the actions for each profile tag for every user and
@@ -35,6 +38,8 @@ class ActionGenerator:
         # event stream, so we just run the rules for a client with no profile
         # tag (ie. we just need all the users).
 
-    async def handle_push_actions_for_event(self, event, context):
+    async def handle_push_actions_for_event(
+        self, event: EventBase, context: EventContext
+    ) -> None:
         with Measure(self.clock, "action_for_event_by_user"):
             await self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index f5788c1de7..6211506990 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -15,16 +15,19 @@
 # limitations under the License.
 
 import copy
+from typing import Any, Dict, List
 
 from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
 
 
-def list_with_base_rules(rawrules, use_new_defaults=False):
+def list_with_base_rules(
+    rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
+) -> List[Dict[str, Any]]:
     """Combine the list of rules set by the user with the default push rules
 
     Args:
-        rawrules(list): The rules the user has modified or set.
-        use_new_defaults(bool): Whether to use the new experimental default rules when
+        rawrules: The rules the user has modified or set.
+        use_new_defaults: Whether to use the new experimental default rules when
             appending or prepending default rules.
 
     Returns:
@@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
     return ruleslist
 
 
-def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
+def make_base_append_rules(
+    kind: str,
+    modified_base_rules: Dict[str, Dict[str, Any]],
+    use_new_defaults: bool = False,
+) -> List[Dict[str, Any]]:
     rules = []
 
     if kind == "override":
@@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
     rules = copy.deepcopy(rules)
     for r in rules:
         # Only modify the actions, keep the conditions the same.
+        assert isinstance(r["rule_id"], str)
         modified = modified_base_rules.get(r["rule_id"])
         if modified:
             r["actions"] = modified["actions"]
@@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
     return rules
 
 
-def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
+def make_base_prepend_rules(
+    kind: str,
+    modified_base_rules: Dict[str, Dict[str, Any]],
+    use_new_defaults: bool = False,
+) -> List[Dict[str, Any]]:
     rules = []
 
     if kind == "override":
@@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
     rules = copy.deepcopy(rules)
     for r in rules:
         # Only modify the actions, keep the conditions the same.
+        assert isinstance(r["rule_id"], str)
         modified = modified_base_rules.get(r["rule_id"])
         if modified:
             r["actions"] = modified["actions"]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 82a72dc34f..9018f9e20b 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 
 import attr
 from prometheus_client import Counter
@@ -25,16 +26,16 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import register_cache
+from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.descriptors import lru_cache
 from synapse.util.caches.lrucache import LruCache
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
-logger = logging.getLogger(__name__)
-
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
-rules_by_room = {}
+logger = logging.getLogger(__name__)
 
 
 push_rules_invalidation_counter = Counter(
@@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
     room at once.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
@@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
             resizable=False,
         )
 
-    async def _get_rules_for_event(self, event, context):
+    async def _get_rules_for_event(
+        self, event: EventBase, context: EventContext
+    ) -> Dict[str, List[Dict[str, Any]]]:
         """This gets the rules for all users in the room at the time of the event,
         as well as the push rules for the invitee if the event is an invite.
 
@@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
         return rules_by_user
 
     @lru_cache()
-    def _get_rules_for_room(self, room_id):
+    def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
         """Get the current RulesForRoom object for the given room id
-
-        Returns:
-            RulesForRoom
         """
         # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
         # before any lookup methods get called on it as otherwise there may be
@@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
             self.room_push_rule_cache_metrics,
         )
 
-    async def _get_power_levels_and_sender_level(self, event, context):
+    async def _get_power_levels_and_sender_level(
+        self, event: EventBase, context: EventContext
+    ) -> Tuple[dict, int]:
         prev_state_ids = await context.get_prev_state_ids()
         pl_event_id = prev_state_ids.get(POWER_KEY)
         if pl_event_id:
             # fastpath: if there's a power level event, that's all we need, and
             # not having a power level event is an extreme edge case
-            pl_event = await self.store.get_event(pl_event_id)
-            auth_events = {POWER_KEY: pl_event}
+            auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
         else:
             auth_events_ids = self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=False
             )
-            auth_events = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+            auth_events_dict = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
 
         sender_level = get_user_power_level(event.sender, auth_events)
 
@@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
 
         return pl_event.content if pl_event else {}, sender_level
 
-    async def action_for_event_by_user(self, event, context) -> None:
+    async def action_for_event_by_user(
+        self, event: EventBase, context: EventContext
+    ) -> None:
         """Given an event and context, evaluate the push rules, check if the message
         should increment the unread count, and insert the results into the
         event_push_actions_staging table.
@@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
         count_as_unread = _should_count_as_unread(event, context)
 
         rules_by_user = await self._get_rules_for_event(event, context)
-        actions_by_user = {}
+        actions_by_user = {}  # type: Dict[str, List[Union[dict, str]]]
 
         room_members = await self.store.get_joined_users_from_context(event, context)
 
@@ -198,16 +201,20 @@ class BulkPushRuleEvaluator:
             event, len(room_members), sender_power_level, power_levels
         )
 
-        condition_cache = {}
+        condition_cache = {}  # type: Dict[str, bool]
+
+        # If the event is not a state event check if any users ignore the sender.
+        if not event.is_state():
+            ignorers = await self.store.ignored_by(event.sender)
+        else:
+            ignorers = set()
 
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
                 continue
 
-            if not event.is_state():
-                is_ignored = await self.store.is_ignored_by(event.sender, uid)
-                if is_ignored:
-                    continue
+            if uid in ignorers:
+                continue
 
             display_name = None
             profile_info = room_members.get(uid)
@@ -249,7 +256,13 @@ class BulkPushRuleEvaluator:
         )
 
 
-def _condition_checker(evaluator, conditions, uid, display_name, cache):
+def _condition_checker(
+    evaluator: PushRuleEvaluatorForEvent,
+    conditions: List[dict],
+    uid: str,
+    display_name: str,
+    cache: Dict[str, bool],
+) -> bool:
     for cond in conditions:
         _id = cond.get("_id", None)
         if _id:
@@ -277,15 +290,19 @@ class RulesForRoom:
     """
 
     def __init__(
-        self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+        self,
+        hs: "HomeServer",
+        room_id: str,
+        rules_for_room_cache: LruCache,
+        room_push_rule_cache_metrics: CacheMetric,
     ):
         """
         Args:
-            hs (HomeServer)
-            room_id (str)
+            hs: The HomeServer object.
+            room_id: The room ID.
             rules_for_room_cache: The cache object that caches these
                 RoomsForUser objects.
-            room_push_rule_cache_metrics (CacheMetric)
+            room_push_rule_cache_metrics: The metrics object
         """
         self.room_id = room_id
         self.is_mine_id = hs.is_mine_id
@@ -294,8 +311,10 @@ class RulesForRoom:
 
         self.linearizer = Linearizer(name="rules_for_room")
 
-        self.member_map = {}  # event_id -> (user_id, state)
-        self.rules_by_user = {}  # user_id -> rules
+        # event_id -> (user_id, state)
+        self.member_map = {}  # type: Dict[str, Tuple[str, str]]
+        # user_id -> rules
+        self.rules_by_user = {}  # type: Dict[str, List[Dict[str, dict]]]
 
         # The last state group we updated the caches for. If the state_group of
         # a new event comes along, we know that we can just return the cached
@@ -315,7 +334,7 @@ class RulesForRoom:
         # calculate push for)
         # These never need to be invalidated as we will never set up push for
         # them.
-        self.uninteresting_user_set = set()
+        self.uninteresting_user_set = set()  # type: Set[str]
 
         # We need to be clever on the invalidating caches callbacks, as
         # otherwise the invalidation callback holds a reference to the object,
@@ -325,7 +344,9 @@ class RulesForRoom:
         # to self around in the callback.
         self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
 
-    async def get_rules(self, event, context):
+    async def get_rules(
+        self, event: EventBase, context: EventContext
+    ) -> Dict[str, List[Dict[str, dict]]]:
         """Given an event context return the rules for all users who are
         currently in the room.
         """
@@ -356,6 +377,8 @@ class RulesForRoom:
             else:
                 current_state_ids = await context.get_current_state_ids()
                 push_rules_delta_state_cache_metric.inc_misses()
+            # Ensure the state IDs exist.
+            assert current_state_ids is not None
 
             push_rules_state_size_counter.inc(len(current_state_ids))
 
@@ -420,18 +443,23 @@ class RulesForRoom:
         return ret_rules_by_user
 
     async def _update_rules_with_member_event_ids(
-        self, ret_rules_by_user, member_event_ids, state_group, event
-    ):
+        self,
+        ret_rules_by_user: Dict[str, list],
+        member_event_ids: Dict[str, str],
+        state_group: Optional[int],
+        event: EventBase,
+    ) -> None:
         """Update the partially filled rules_by_user dict by fetching rules for
         any newly joined users in the `member_event_ids` list.
 
         Args:
-            ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
+            ret_rules_by_user: Partially filled dict of push rules. Gets
                 updated with any new rules.
-            member_event_ids (dict): Dict of user id to event id for membership events
+            member_event_ids: Dict of user id to event id for membership events
                 that have happened since the last time we filled rules_by_user
             state_group: The state group we are currently computing push rules
                 for. Used when updating the cache.
+            event: The event we are currently computing push rules for.
         """
         sequence = self.sequence
 
@@ -449,19 +477,19 @@ class RulesForRoom:
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug("Found members %r: %r", self.room_id, members.values())
 
-        user_ids = {
+        joined_user_ids = {
             user_id
             for user_id, membership in members.values()
             if membership == Membership.JOIN
         }
 
-        logger.debug("Joined: %r", user_ids)
+        logger.debug("Joined: %r", joined_user_ids)
 
         # Previously we only considered users with pushers or read receipts in that
         # room. We can't do this anymore because we use push actions to calculate unread
         # counts, which don't rely on the user having pushers or sent a read receipt into
         # the room. Therefore we just need to filter for local users here.
-        user_ids = list(filter(self.is_mine_id, user_ids))
+        user_ids = list(filter(self.is_mine_id, joined_user_ids))
 
         rules_by_user = await self.store.bulk_get_push_rules(
             user_ids, on_invalidate=self.invalidate_all_cb
@@ -473,7 +501,7 @@ class RulesForRoom:
 
         self.update_cache(sequence, members, ret_rules_by_user, state_group)
 
-    def invalidate_all(self):
+    def invalidate_all(self) -> None:
         # Note: Don't hand this function directly to an invalidation callback
         # as it keeps a reference to self and will stop this instance from being
         # GC'd if it gets dropped from the rules_to_user cache. Instead use
@@ -485,7 +513,7 @@ class RulesForRoom:
         self.rules_by_user = {}
         push_rules_invalidation_counter.inc()
 
-    def update_cache(self, sequence, members, rules_by_user, state_group):
+    def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
         if sequence == self.sequence:
             self.member_map.update(members)
             self.rules_by_user = rules_by_user
@@ -506,7 +534,7 @@ class _Invalidation:
     cache = attr.ib(type=LruCache)
     room_id = attr.ib(type=str)
 
-    def __call__(self):
+    def __call__(self) -> None:
         rules = self.cache.get(self.room_id, None, update_metrics=False)
         if rules:
             rules.invalidate_all()
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index a59b639f15..0cadba761a 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -14,24 +14,27 @@
 # limitations under the License.
 
 import copy
+from typing import Any, Dict, List, Optional
 
 from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+from synapse.types import UserID
 
 
-def format_push_rules_for_user(user, ruleslist):
+def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
     """Converts a list of rawrules and a enabled map into nested dictionaries
     to match the Matrix client-server format for push rules"""
 
     # We're going to be mutating this a lot, so do a deep copy
     ruleslist = copy.deepcopy(ruleslist)
 
-    rules = {"global": {}, "device": {}}
+    rules = {
+        "global": {},
+        "device": {},
+    }  # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
 
     rules["global"] = _add_empty_priority_class_arrays(rules["global"])
 
     for r in ruleslist:
-        rulearray = None
-
         template_name = _priority_class_to_template_name(r["priority_class"])
 
         # Remove internal stuff.
@@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
     return rules
 
 
-def _add_empty_priority_class_arrays(d):
+def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
     for pc in PRIORITY_CLASS_MAP.keys():
         d[pc] = []
     return d
 
 
-def _rule_to_template(rule):
+def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
     unscoped_rule_id = None
     if "rule_id" in rule:
         unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
@@ -82,6 +85,10 @@ def _rule_to_template(rule):
             return None
         templaterule = {"actions": rule["actions"]}
         templaterule["pattern"] = thecond["pattern"]
+    else:
+        # This should not be reached unless this function is not kept in sync
+        # with PRIORITY_CLASS_INVERSE_MAP.
+        raise ValueError("Unexpected template_name: %s" % (template_name,))
 
     if unscoped_rule_id:
         templaterule["rule_id"] = unscoped_rule_id
@@ -90,9 +97,9 @@ def _rule_to_template(rule):
     return templaterule
 
 
-def _rule_id_from_namespaced(in_rule_id):
+def _rule_id_from_namespaced(in_rule_id: str) -> str:
     return in_rule_id.split("/")[-1]
 
 
-def _priority_class_to_template_name(pc):
+def _priority_class_to_template_name(pc: int) -> str:
     return PRIORITY_CLASS_INVERSE_MAP[pc]
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index c6763971ee..4ac1b31748 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -14,11 +14,17 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Dict, List, Optional
 
+from twisted.internet.base import DelayedCall
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import RoomStreamToken
+from synapse.push import Pusher, PusherConfig, ThrottleParams
+from synapse.push.mailer import Mailer
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +52,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
 INCLUDE_ALL_UNREAD_NOTIFS = False
 
 
-class EmailPusher:
+class EmailPusher(Pusher):
     """
     A pusher that sends email notifications about events (approximately)
     when they happen.
@@ -54,37 +60,30 @@ class EmailPusher:
     factor out the common parts
     """
 
-    def __init__(self, hs, pusherdict, mailer):
-        self.hs = hs
+    def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
+        super().__init__(hs, pusher_config)
         self.mailer = mailer
 
         self.store = self.hs.get_datastore()
-        self.clock = self.hs.get_clock()
-        self.pusher_id = pusherdict["id"]
-        self.user_id = pusherdict["user_name"]
-        self.app_id = pusherdict["app_id"]
-        self.email = pusherdict["pushkey"]
-        self.last_stream_ordering = pusherdict["last_stream_ordering"]
-        self.timed_call = None
-        self.throttle_params = None
-
-        # See httppusher
-        self.max_stream_ordering = None
+        self.email = pusher_config.pushkey
+        self.timed_call = None  # type: Optional[DelayedCall]
+        self.throttle_params = {}  # type: Dict[str, ThrottleParams]
+        self._inited = False
 
         self._is_processing = False
 
-    def on_started(self, should_check_for_notifs):
+    def on_started(self, should_check_for_notifs: bool) -> None:
         """Called when this pusher has been started.
 
         Args:
-            should_check_for_notifs (bool): Whether we should immediately
+            should_check_for_notifs: Whether we should immediately
                 check for push to send. Set to False only if it's known there
                 is nothing to send
         """
         if should_check_for_notifs and self.mailer is not None:
             self._start_processing()
 
-    def on_stop(self):
+    def on_stop(self) -> None:
         if self.timed_call:
             try:
                 self.timed_call.cancel()
@@ -92,37 +91,23 @@ class EmailPusher:
                 pass
             self.timed_call = None
 
-    def on_new_notifications(self, max_token: RoomStreamToken):
-        # We just use the minimum stream ordering and ignore the vector clock
-        # component. This is safe to do as long as we *always* ignore the vector
-        # clock components.
-        max_stream_ordering = max_token.stream
-
-        if self.max_stream_ordering:
-            self.max_stream_ordering = max(
-                max_stream_ordering, self.max_stream_ordering
-            )
-        else:
-            self.max_stream_ordering = max_stream_ordering
-        self._start_processing()
-
-    def on_new_receipts(self, min_stream_id, max_stream_id):
+    def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
         # We could wake up and cancel the timer but there tend to be quite a
         # lot of read receipts so it's probably less work to just let the
         # timer fire
         pass
 
-    def on_timer(self):
+    def on_timer(self) -> None:
         self.timed_call = None
         self._start_processing()
 
-    def _start_processing(self):
+    def _start_processing(self) -> None:
         if self._is_processing:
             return
 
         run_as_background_process("emailpush.process", self._process)
 
-    def _pause_processing(self):
+    def _pause_processing(self) -> None:
         """Used by tests to temporarily pause processing of events.
 
         Asserts that its not currently processing.
@@ -130,25 +115,27 @@ class EmailPusher:
         assert not self._is_processing
         self._is_processing = True
 
-    def _resume_processing(self):
+    def _resume_processing(self) -> None:
         """Used by tests to resume processing of events after pausing.
         """
         assert self._is_processing
         self._is_processing = False
         self._start_processing()
 
-    async def _process(self):
+    async def _process(self) -> None:
         # we should never get here if we are already processing
         assert not self._is_processing
 
         try:
             self._is_processing = True
 
-            if self.throttle_params is None:
+            if not self._inited:
                 # this is our first loop: load up the throttle params
+                assert self.pusher_id is not None
                 self.throttle_params = await self.store.get_throttle_params_by_room(
                     self.pusher_id
                 )
+                self._inited = True
 
             # if the max ordering changes while we're running _unsafe_process,
             # call it again, and so on until we've caught up.
@@ -163,17 +150,18 @@ class EmailPusher:
         finally:
             self._is_processing = False
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         """
         Main logic of the push loop without the wrapper function that sets
         up logging, measures and guards against multiple instances of it
         being run.
         """
         start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
-        fn = self.store.get_unread_push_actions_for_user_in_range_for_email
-        unprocessed = await fn(self.user_id, start, self.max_stream_ordering)
+        unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
+            self.user_id, start, self.max_stream_ordering
+        )
 
-        soonest_due_at = None
+        soonest_due_at = None  # type: Optional[int]
 
         if not unprocessed:
             await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
@@ -230,11 +218,9 @@ class EmailPusher:
                 self.seconds_until(soonest_due_at), self.on_timer
             )
 
-    async def save_last_stream_ordering_and_success(self, last_stream_ordering):
-        if last_stream_ordering is None:
-            # This happens if we haven't yet processed anything
-            return
-
+    async def save_last_stream_ordering_and_success(
+        self, last_stream_ordering: int
+    ) -> None:
         self.last_stream_ordering = last_stream_ordering
         pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
             self.app_id,
@@ -248,28 +234,30 @@ class EmailPusher:
             # lets just stop and return.
             self.on_stop()
 
-    def seconds_until(self, ts_msec):
+    def seconds_until(self, ts_msec: int) -> float:
         secs = (ts_msec - self.clock.time_msec()) / 1000
         return max(secs, 0)
 
-    def get_room_throttle_ms(self, room_id):
+    def get_room_throttle_ms(self, room_id: str) -> int:
         if room_id in self.throttle_params:
-            return self.throttle_params[room_id]["throttle_ms"]
+            return self.throttle_params[room_id].throttle_ms
         else:
             return 0
 
-    def get_room_last_sent_ts(self, room_id):
+    def get_room_last_sent_ts(self, room_id: str) -> int:
         if room_id in self.throttle_params:
-            return self.throttle_params[room_id]["last_sent_ts"]
+            return self.throttle_params[room_id].last_sent_ts
         else:
             return 0
 
-    def room_ready_to_notify_at(self, room_id):
+    def room_ready_to_notify_at(self, room_id: str) -> int:
         """
         Determines whether throttling should prevent us from sending an email
         for the given room
-        Returns: The timestamp when we are next allowed to send an email notif
-        for this room
+
+        Returns:
+            The timestamp when we are next allowed to send an email notif
+            for this room
         """
         last_sent_ts = self.get_room_last_sent_ts(room_id)
         throttle_ms = self.get_room_throttle_ms(room_id)
@@ -277,7 +265,9 @@ class EmailPusher:
         may_send_at = last_sent_ts + throttle_ms
         return may_send_at
 
-    async def sent_notif_update_throttle(self, room_id, notified_push_action):
+    async def sent_notif_update_throttle(
+        self, room_id: str, notified_push_action: dict
+    ) -> None:
         # We have sent a notification, so update the throttle accordingly.
         # If the event that triggered the notif happened more than
         # THROTTLE_RESET_AFTER_MS after the previous one that triggered a
@@ -307,15 +297,15 @@ class EmailPusher:
                 new_throttle_ms = min(
                     current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
                 )
-        self.throttle_params[room_id] = {
-            "last_sent_ts": self.clock.time_msec(),
-            "throttle_ms": new_throttle_ms,
-        }
+        self.throttle_params[room_id] = ThrottleParams(
+            self.clock.time_msec(), new_throttle_ms,
+        )
+        assert self.pusher_id is not None
         await self.store.set_throttle_params(
             self.pusher_id, room_id, self.throttle_params[room_id]
         )
 
-    async def send_notification(self, push_actions, reason):
+    async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
         logger.info("Sending notif email for user %r", self.user_id)
 
         await self.mailer.send_notification_mail(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index eff0975b6a..e048b0d59e 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -14,19 +14,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import urllib.parse
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
 
 from prometheus_client import Counter
 
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
 from synapse.api.constants import EventTypes
+from synapse.events import EventBase
 from synapse.logging import opentracing
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import PusherConfigException
-from synapse.types import RoomStreamToken
+from synapse.push import Pusher, PusherConfig, PusherConfigException
 
 from . import push_rule_evaluator, push_tools
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 http_push_processed_counter = Counter(
@@ -50,91 +55,76 @@ http_badges_failed_counter = Counter(
 )
 
 
-class HttpPusher:
+class HttpPusher(Pusher):
     INITIAL_BACKOFF_SEC = 1  # in seconds because that's what Twisted takes
     MAX_BACKOFF_SEC = 60 * 60
 
     # This one's in ms because we compare it against the clock
     GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
 
-    def __init__(self, hs, pusherdict):
-        self.hs = hs
-        self.store = self.hs.get_datastore()
+    def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
+        super().__init__(hs, pusher_config)
         self.storage = self.hs.get_storage()
-        self.clock = self.hs.get_clock()
-        self.state_handler = self.hs.get_state_handler()
-        self.user_id = pusherdict["user_name"]
-        self.app_id = pusherdict["app_id"]
-        self.app_display_name = pusherdict["app_display_name"]
-        self.device_display_name = pusherdict["device_display_name"]
-        self.pushkey = pusherdict["pushkey"]
-        self.pushkey_ts = pusherdict["ts"]
-        self.data = pusherdict["data"]
-        self.last_stream_ordering = pusherdict["last_stream_ordering"]
+        self.app_display_name = pusher_config.app_display_name
+        self.device_display_name = pusher_config.device_display_name
+        self.pushkey_ts = pusher_config.ts
+        self.data = pusher_config.data
         self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
-        self.failing_since = pusherdict["failing_since"]
+        self.failing_since = pusher_config.failing_since
         self.timed_call = None
         self._is_processing = False
         self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
 
-        # This is the highest stream ordering we know it's safe to process.
-        # When new events arrive, we'll be given a window of new events: we
-        # should honour this rather than just looking for anything higher
-        # because of potential out-of-order event serialisation. This starts
-        # off as None though as we don't know any better.
-        self.max_stream_ordering = None
-
-        if "data" not in pusherdict:
-            raise PusherConfigException("No 'data' key for HTTP pusher")
-        self.data = pusherdict["data"]
+        self.data = pusher_config.data
+        if self.data is None:
+            raise PusherConfigException("'data' key can not be null for HTTP pusher")
 
         self.name = "%s/%s/%s" % (
-            pusherdict["user_name"],
-            pusherdict["app_id"],
-            pusherdict["pushkey"],
+            pusher_config.user_name,
+            pusher_config.app_id,
+            pusher_config.pushkey,
         )
 
-        if self.data is None:
-            raise PusherConfigException("data can not be null for HTTP pusher")
-
+        # Validate that there's a URL and it is of the proper form.
         if "url" not in self.data:
             raise PusherConfigException("'url' required in data for HTTP pusher")
-        self.url = self.data["url"]
-        self.http_client = hs.get_proxied_http_client()
+
+        url = self.data["url"]
+        if not isinstance(url, str):
+            raise PusherConfigException("'url' must be a string")
+        url_parts = urllib.parse.urlparse(url)
+        # Note that the specification also says the scheme must be HTTPS, but
+        # it isn't up to the homeserver to verify that.
+        if url_parts.path != "/_matrix/push/v1/notify":
+            raise PusherConfigException(
+                "'url' must have a path of '/_matrix/push/v1/notify'"
+            )
+
+        self.url = url
+        self.http_client = hs.get_proxied_blacklisted_http_client()
         self.data_minus_url = {}
         self.data_minus_url.update(self.data)
         del self.data_minus_url["url"]
 
-    def on_started(self, should_check_for_notifs):
+    def on_started(self, should_check_for_notifs: bool) -> None:
         """Called when this pusher has been started.
 
         Args:
-            should_check_for_notifs (bool): Whether we should immediately
+            should_check_for_notifs: Whether we should immediately
                 check for push to send. Set to False only if it's known there
                 is nothing to send
         """
         if should_check_for_notifs:
             self._start_processing()
 
-    def on_new_notifications(self, max_token: RoomStreamToken):
-        # We just use the minimum stream ordering and ignore the vector clock
-        # component. This is safe to do as long as we *always* ignore the vector
-        # clock components.
-        max_stream_ordering = max_token.stream
-
-        self.max_stream_ordering = max(
-            max_stream_ordering, self.max_stream_ordering or 0
-        )
-        self._start_processing()
-
-    def on_new_receipts(self, min_stream_id, max_stream_id):
+    def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
         # Note that the min here shouldn't be relied upon to be accurate.
 
         # We could check the receipts are actually m.read receipts here,
         # but currently that's the only type of receipt anyway...
         run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
 
-    async def _update_badge(self):
+    async def _update_badge(self) -> None:
         # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
         # to be largely redundant. perhaps we can remove it.
         badge = await push_tools.get_badge_count(
@@ -144,10 +134,10 @@ class HttpPusher:
         )
         await self._send_badge(badge)
 
-    def on_timer(self):
+    def on_timer(self) -> None:
         self._start_processing()
 
-    def on_stop(self):
+    def on_stop(self) -> None:
         if self.timed_call:
             try:
                 self.timed_call.cancel()
@@ -155,13 +145,13 @@ class HttpPusher:
                 pass
             self.timed_call = None
 
-    def _start_processing(self):
+    def _start_processing(self) -> None:
         if self._is_processing:
             return
 
         run_as_background_process("httppush.process", self._process)
 
-    async def _process(self):
+    async def _process(self) -> None:
         # we should never get here if we are already processing
         assert not self._is_processing
 
@@ -180,15 +170,13 @@ class HttpPusher:
         finally:
             self._is_processing = False
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         """
         Looks for unset notifications and dispatch them, in order
         Never call this directly: use _process which will only allow this to
         run once per pusher.
         """
-
-        fn = self.store.get_unread_push_actions_for_user_in_range_for_http
-        unprocessed = await fn(
+        unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
             self.user_id, self.last_stream_ordering, self.max_stream_ordering
         )
 
@@ -257,17 +245,12 @@ class HttpPusher:
                     )
                     self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
                     self.last_stream_ordering = push_action["stream_ordering"]
-                    pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
+                    await self.store.update_pusher_last_stream_ordering(
                         self.app_id,
                         self.pushkey,
                         self.user_id,
                         self.last_stream_ordering,
                     )
-                    if not pusher_still_exists:
-                        # The pusher has been deleted while we were processing, so
-                        # lets just stop and return.
-                        self.on_stop()
-                        return
 
                     self.failing_since = None
                     await self.store.update_pusher_failing_since(
@@ -283,7 +266,7 @@ class HttpPusher:
                     )
                     break
 
-    async def _process_one(self, push_action):
+    async def _process_one(self, push_action: dict) -> bool:
         if "notify" not in push_action["actions"]:
             return True
 
@@ -314,7 +297,9 @@ class HttpPusher:
                     await self.hs.remove_pusher(self.app_id, pk, self.user_id)
         return True
 
-    async def _build_notification_dict(self, event, tweaks, badge):
+    async def _build_notification_dict(
+        self, event: EventBase, tweaks: Dict[str, bool], badge: int
+    ) -> Dict[str, Any]:
         priority = "low"
         if (
             event.type == EventTypes.Encrypted
@@ -325,6 +310,8 @@ class HttpPusher:
             #  or may do so (i.e. is encrypted so has unknown effects).
             priority = "high"
 
+        # This was checked in the __init__, but mypy doesn't seem to know that.
+        assert self.data is not None
         if self.data.get("format") == "event_id_only":
             d = {
                 "notification": {
@@ -344,9 +331,7 @@ class HttpPusher:
             }
             return d
 
-        ctx = await push_tools.get_context_for_event(
-            self.storage, self.state_handler, event, self.user_id
-        )
+        ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
 
         d = {
             "notification": {
@@ -386,7 +371,9 @@ class HttpPusher:
 
         return d
 
-    async def dispatch_push(self, event, tweaks, badge):
+    async def dispatch_push(
+        self, event: EventBase, tweaks: Dict[str, bool], badge: int
+    ) -> Union[bool, Iterable[str]]:
         notification_dict = await self._build_notification_dict(event, tweaks, badge)
         if not notification_dict:
             return []
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 38195c8eea..8a6dcff30d 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -19,7 +19,7 @@ import logging
 import urllib.parse
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-from typing import Iterable, List, TypeVar
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
 
 import bleach
 import jinja2
@@ -27,16 +27,20 @@ import jinja2
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import StoreError
 from synapse.config.emailconfig import EmailSubjectConfig
+from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable
 from synapse.push.presentable_names import (
     calculate_room_name,
     descriptor_from_member_events,
     name_from_member_event,
 )
-from synapse.types import UserID
+from synapse.types import StateMap, UserID
 from synapse.util.async_helpers import concurrently_execute
 from synapse.visibility import filter_events_for_client
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 T = TypeVar("T")
@@ -93,7 +97,13 @@ ALLOWED_ATTRS = {
 
 
 class Mailer:
-    def __init__(self, hs, app_name, template_html, template_text):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        app_name: str,
+        template_html: jinja2.Template,
+        template_text: jinja2.Template,
+    ):
         self.hs = hs
         self.template_html = template_html
         self.template_text = template_text
@@ -108,17 +118,19 @@ class Mailer:
 
         logger.info("Created Mailer for app_name %s" % app_name)
 
-    async def send_password_reset_mail(self, email_address, token, client_secret, sid):
+    async def send_password_reset_mail(
+        self, email_address: str, token: str, client_secret: str, sid: str
+    ) -> None:
         """Send an email with a password reset link to a user
 
         Args:
-            email_address (str): Email address we're sending the password
+            email_address: Email address we're sending the password
                 reset to
-            token (str): Unique token generated by the server to verify
+            token: Unique token generated by the server to verify
                 the email was received
-            client_secret (str): Unique token generated by the client to
+            client_secret: Unique token generated by the client to
                 group together multiple email sending attempts
-            sid (str): The generated session ID
+            sid: The generated session ID
         """
         params = {"token": token, "client_secret": client_secret, "sid": sid}
         link = (
@@ -136,17 +148,19 @@ class Mailer:
             template_vars,
         )
 
-    async def send_registration_mail(self, email_address, token, client_secret, sid):
+    async def send_registration_mail(
+        self, email_address: str, token: str, client_secret: str, sid: str
+    ) -> None:
         """Send an email with a registration confirmation link to a user
 
         Args:
-            email_address (str): Email address we're sending the registration
+            email_address: Email address we're sending the registration
                 link to
-            token (str): Unique token generated by the server to verify
+            token: Unique token generated by the server to verify
                 the email was received
-            client_secret (str): Unique token generated by the client to
+            client_secret: Unique token generated by the client to
                 group together multiple email sending attempts
-            sid (str): The generated session ID
+            sid: The generated session ID
         """
         params = {"token": token, "client_secret": client_secret, "sid": sid}
         link = (
@@ -164,18 +178,20 @@ class Mailer:
             template_vars,
         )
 
-    async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
+    async def send_add_threepid_mail(
+        self, email_address: str, token: str, client_secret: str, sid: str
+    ) -> None:
         """Send an email with a validation link to a user for adding a 3pid to their account
 
         Args:
-            email_address (str): Email address we're sending the validation link to
+            email_address: Email address we're sending the validation link to
 
-            token (str): Unique token generated by the server to verify the email was received
+            token: Unique token generated by the server to verify the email was received
 
-            client_secret (str): Unique token generated by the client to group together
+            client_secret: Unique token generated by the client to group together
                 multiple email sending attempts
 
-            sid (str): The generated session ID
+            sid: The generated session ID
         """
         params = {"token": token, "client_secret": client_secret, "sid": sid}
         link = (
@@ -194,8 +210,13 @@ class Mailer:
         )
 
     async def send_notification_mail(
-        self, app_id, user_id, email_address, push_actions, reason
-    ):
+        self,
+        app_id: str,
+        user_id: str,
+        email_address: str,
+        push_actions: Iterable[Dict[str, Any]],
+        reason: Dict[str, Any],
+    ) -> None:
         """Send email regarding a user's room notifications"""
         rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
 
@@ -203,7 +224,7 @@ class Mailer:
             [pa["event_id"] for pa in push_actions]
         )
 
-        notifs_by_room = {}
+        notifs_by_room = {}  # type: Dict[str, List[Dict[str, Any]]]
         for pa in push_actions:
             notifs_by_room.setdefault(pa["room_id"], []).append(pa)
 
@@ -246,9 +267,21 @@ class Mailer:
             fallback_to_members=True,
         )
 
-        summary_text = await self.make_summary_text(
-            notifs_by_room, state_by_room, notif_events, user_id, reason
-        )
+        if len(notifs_by_room) == 1:
+            # Only one room has new stuff
+            room_id = list(notifs_by_room.keys())[0]
+
+            summary_text = await self.make_summary_text_single_room(
+                room_id,
+                notifs_by_room[room_id],
+                state_by_room[room_id],
+                notif_events,
+                user_id,
+            )
+        else:
+            summary_text = await self.make_summary_text(
+                notifs_by_room, state_by_room, notif_events, reason
+            )
 
         template_vars = {
             "user_display_name": user_display_name,
@@ -262,7 +295,9 @@ class Mailer:
 
         await self.send_email(email_address, summary_text, template_vars)
 
-    async def send_email(self, email_address, subject, extra_template_vars):
+    async def send_email(
+        self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
+    ) -> None:
         """Send an email with the given information and template text"""
         try:
             from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@@ -315,8 +350,13 @@ class Mailer:
         )
 
     async def get_room_vars(
-        self, room_id, user_id, notifs, notif_events, room_state_ids
-    ):
+        self,
+        room_id: str,
+        user_id: str,
+        notifs: Iterable[Dict[str, Any]],
+        notif_events: Dict[str, EventBase],
+        room_state_ids: StateMap[str],
+    ) -> Dict[str, Any]:
         # Check if one of the notifs is an invite event for the user.
         is_invite = False
         for n in notifs:
@@ -334,7 +374,7 @@ class Mailer:
             "notifs": [],
             "invite": is_invite,
             "link": self.make_room_link(room_id),
-        }
+        }  # type: Dict[str, Any]
 
         if not is_invite:
             for n in notifs:
@@ -365,7 +405,13 @@ class Mailer:
 
         return room_vars
 
-    async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
+    async def get_notif_vars(
+        self,
+        notif: Dict[str, Any],
+        user_id: str,
+        notif_event: EventBase,
+        room_state_ids: StateMap[str],
+    ) -> Dict[str, Any]:
         results = await self.store.get_events_around(
             notif["room_id"],
             notif["event_id"],
@@ -391,7 +437,9 @@ class Mailer:
 
         return ret
 
-    async def get_message_vars(self, notif, event, room_state_ids):
+    async def get_message_vars(
+        self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
+    ) -> Optional[Dict[str, Any]]:
         if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
             return None
 
@@ -432,7 +480,9 @@ class Mailer:
 
         return ret
 
-    def add_text_message_vars(self, messagevars, event):
+    def add_text_message_vars(
+        self, messagevars: Dict[str, Any], event: EventBase
+    ) -> None:
         msgformat = event.content.get("format")
 
         messagevars["format"] = msgformat
@@ -445,142 +495,188 @@ class Mailer:
         elif body:
             messagevars["body_text_html"] = safe_text(body)
 
-        return messagevars
-
-    def add_image_message_vars(self, messagevars, event):
-        messagevars["image_url"] = event.content["url"]
+    def add_image_message_vars(
+        self, messagevars: Dict[str, Any], event: EventBase
+    ) -> None:
+        """
+        Potentially add an image URL to the message variables.
+        """
+        if "url" in event.content:
+            messagevars["image_url"] = event.content["url"]
+
+    async def make_summary_text_single_room(
+        self,
+        room_id: str,
+        notifs: List[Dict[str, Any]],
+        room_state_ids: StateMap[str],
+        notif_events: Dict[str, EventBase],
+        user_id: str,
+    ) -> str:
+        """
+        Make a summary text for the email when only a single room has notifications.
 
-        return messagevars
+        Args:
+            room_id: The ID of the room.
+            notifs: The notifications for this room.
+            room_state_ids: The state map for the room.
+            notif_events: A map of event ID -> notification event.
+            user_id: The user receiving the notification.
+
+        Returns:
+            The summary text.
+        """
+        # If the room has some kind of name, use it, but we don't
+        # want the generated-from-names one here otherwise we'll
+        # end up with, "new message from Bob in the Bob room"
+        room_name = await calculate_room_name(
+            self.store, room_state_ids, user_id, fallback_to_members=False
+        )
 
-    async def make_summary_text(
-        self, notifs_by_room, room_state_ids, notif_events, user_id, reason
-    ):
-        if len(notifs_by_room) == 1:
-            # Only one room has new stuff
-            room_id = list(notifs_by_room.keys())[0]
+        # See if one of the notifs is an invite event for the user
+        invite_event = None
+        for n in notifs:
+            ev = notif_events[n["event_id"]]
+            if ev.type == EventTypes.Member and ev.state_key == user_id:
+                if ev.content.get("membership") == Membership.INVITE:
+                    invite_event = ev
+                    break
 
-            # If the room has some kind of name, use it, but we don't
-            # want the generated-from-names one here otherwise we'll
-            # end up with, "new message from Bob in the Bob room"
-            room_name = await calculate_room_name(
-                self.store, room_state_ids[room_id], user_id, fallback_to_members=False
+        if invite_event:
+            inviter_member_event_id = room_state_ids.get(
+                ("m.room.member", invite_event.sender)
             )
-
-            # See if one of the notifs is an invite event for the user
-            invite_event = None
-            for n in notifs_by_room[room_id]:
-                ev = notif_events[n["event_id"]]
-                if ev.type == EventTypes.Member and ev.state_key == user_id:
-                    if ev.content.get("membership") == Membership.INVITE:
-                        invite_event = ev
-                        break
-
-            if invite_event:
-                inviter_member_event_id = room_state_ids[room_id].get(
-                    ("m.room.member", invite_event.sender)
+            inviter_name = invite_event.sender
+            if inviter_member_event_id:
+                inviter_member_event = await self.store.get_event(
+                    inviter_member_event_id, allow_none=True
                 )
-                inviter_name = invite_event.sender
-                if inviter_member_event_id:
-                    inviter_member_event = await self.store.get_event(
-                        inviter_member_event_id, allow_none=True
-                    )
-                    if inviter_member_event:
-                        inviter_name = name_from_member_event(inviter_member_event)
-
-                if room_name is None:
-                    return self.email_subjects.invite_from_person % {
-                        "person": inviter_name,
-                        "app": self.app_name,
-                    }
-                else:
-                    return self.email_subjects.invite_from_person_to_room % {
-                        "person": inviter_name,
-                        "room": room_name,
-                        "app": self.app_name,
-                    }
+                if inviter_member_event:
+                    inviter_name = name_from_member_event(inviter_member_event)
 
-            sender_name = None
-            if len(notifs_by_room[room_id]) == 1:
-                # There is just the one notification, so give some detail
-                event = notif_events[notifs_by_room[room_id][0]["event_id"]]
-                if ("m.room.member", event.sender) in room_state_ids[room_id]:
-                    state_event_id = room_state_ids[room_id][
-                        ("m.room.member", event.sender)
-                    ]
-                    state_event = await self.store.get_event(state_event_id)
-                    sender_name = name_from_member_event(state_event)
-
-                if sender_name is not None and room_name is not None:
-                    return self.email_subjects.message_from_person_in_room % {
-                        "person": sender_name,
-                        "room": room_name,
-                        "app": self.app_name,
-                    }
-                elif sender_name is not None:
-                    return self.email_subjects.message_from_person % {
-                        "person": sender_name,
-                        "app": self.app_name,
-                    }
-            else:
-                # There's more than one notification for this room, so just
-                # say there are several
-                if room_name is not None:
-                    return self.email_subjects.messages_in_room % {
-                        "room": room_name,
-                        "app": self.app_name,
-                    }
-                else:
-                    # If the room doesn't have a name, say who the messages
-                    # are from explicitly to avoid, "messages in the Bob room"
-                    sender_ids = list(
-                        {
-                            notif_events[n["event_id"]].sender
-                            for n in notifs_by_room[room_id]
-                        }
-                    )
-
-                    member_events = await self.store.get_events(
-                        [
-                            room_state_ids[room_id][("m.room.member", s)]
-                            for s in sender_ids
-                        ]
-                    )
-
-                    return self.email_subjects.messages_from_person % {
-                        "person": descriptor_from_member_events(member_events.values()),
-                        "app": self.app_name,
-                    }
-        else:
-            # Stuff's happened in multiple different rooms
-
-            # ...but we still refer to the 'reason' room which triggered the mail
-            if reason["room_name"] is not None:
-                return self.email_subjects.messages_in_room_and_others % {
-                    "room": reason["room_name"],
+            if room_name is None:
+                return self.email_subjects.invite_from_person % {
+                    "person": inviter_name,
                     "app": self.app_name,
                 }
-            else:
-                # If the reason room doesn't have a name, say who the messages
-                # are from explicitly to avoid, "messages in the Bob room"
-                room_id = reason["room_id"]
-
-                sender_ids = list(
-                    {
-                        notif_events[n["event_id"]].sender
-                        for n in notifs_by_room[room_id]
-                    }
-                )
 
-                member_events = await self.store.get_events(
-                    [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
-                )
+            return self.email_subjects.invite_from_person_to_room % {
+                "person": inviter_name,
+                "room": room_name,
+                "app": self.app_name,
+            }
 
-                return self.email_subjects.messages_from_person_and_others % {
-                    "person": descriptor_from_member_events(member_events.values()),
+        if len(notifs) == 1:
+            # There is just the one notification, so give some detail
+            sender_name = None
+            event = notif_events[notifs[0]["event_id"]]
+            if ("m.room.member", event.sender) in room_state_ids:
+                state_event_id = room_state_ids[("m.room.member", event.sender)]
+                state_event = await self.store.get_event(state_event_id)
+                sender_name = name_from_member_event(state_event)
+
+            if sender_name is not None and room_name is not None:
+                return self.email_subjects.message_from_person_in_room % {
+                    "person": sender_name,
+                    "room": room_name,
                     "app": self.app_name,
                 }
+            elif sender_name is not None:
+                return self.email_subjects.message_from_person % {
+                    "person": sender_name,
+                    "app": self.app_name,
+                }
+
+            # The sender is unknown, just use the room name (or ID).
+            return self.email_subjects.messages_in_room % {
+                "room": room_name or room_id,
+                "app": self.app_name,
+            }
+        else:
+            # There's more than one notification for this room, so just
+            # say there are several
+            if room_name is not None:
+                return self.email_subjects.messages_in_room % {
+                    "room": room_name,
+                    "app": self.app_name,
+                }
+
+            return await self.make_summary_text_from_member_events(
+                room_id, notifs, room_state_ids, notif_events
+            )
+
+    async def make_summary_text(
+        self,
+        notifs_by_room: Dict[str, List[Dict[str, Any]]],
+        room_state_ids: Dict[str, StateMap[str]],
+        notif_events: Dict[str, EventBase],
+        reason: Dict[str, Any],
+    ) -> str:
+        """
+        Make a summary text for the email when multiple rooms have notifications.
+
+        Args:
+            notifs_by_room: A map of room ID to the notifications for that room.
+            room_state_ids: A map of room ID to the state map for that room.
+            notif_events: A map of event ID -> notification event.
+            reason: The reason this notification is being sent.
+
+        Returns:
+            The summary text.
+        """
+        # Stuff's happened in multiple different rooms
+        # ...but we still refer to the 'reason' room which triggered the mail
+        if reason["room_name"] is not None:
+            return self.email_subjects.messages_in_room_and_others % {
+                "room": reason["room_name"],
+                "app": self.app_name,
+            }
+
+        room_id = reason["room_id"]
+        return await self.make_summary_text_from_member_events(
+            room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events
+        )
+
+    async def make_summary_text_from_member_events(
+        self,
+        room_id: str,
+        notifs: List[Dict[str, Any]],
+        room_state_ids: StateMap[str],
+        notif_events: Dict[str, EventBase],
+    ) -> str:
+        """
+        Make a summary text for the email when only a single room has notifications.
+
+        Args:
+            room_id: The ID of the room.
+            notifs: The notifications for this room.
+            room_state_ids: The state map for the room.
+            notif_events: A map of event ID -> notification event.
+
+        Returns:
+            The summary text.
+        """
+        # If the room doesn't have a name, say who the messages
+        # are from explicitly to avoid, "messages in the Bob room"
+        sender_ids = {notif_events[n["event_id"]].sender for n in notifs}
+
+        member_events = await self.store.get_events(
+            [room_state_ids[("m.room.member", s)] for s in sender_ids]
+        )
 
-    def make_room_link(self, room_id):
+        # There was a single sender.
+        if len(sender_ids) == 1:
+            return self.email_subjects.messages_from_person % {
+                "person": descriptor_from_member_events(member_events.values()),
+                "app": self.app_name,
+            }
+
+        # There was more than one sender, use the first one and a tweaked template.
+        return self.email_subjects.messages_from_person_and_others % {
+            "person": descriptor_from_member_events(list(member_events.values())[:1]),
+            "app": self.app_name,
+        }
+
+    def make_room_link(self, room_id: str) -> str:
         if self.hs.config.email_riot_base_url:
             base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
         elif self.app_name == "Vector":
@@ -590,7 +686,7 @@ class Mailer:
             base_url = "https://matrix.to/#"
         return "%s/%s" % (base_url, room_id)
 
-    def make_notif_link(self, notif):
+    def make_notif_link(self, notif: Dict[str, str]) -> str:
         if self.hs.config.email_riot_base_url:
             return "%s/#/room/%s/%s" % (
                 self.hs.config.email_riot_base_url,
@@ -606,7 +702,9 @@ class Mailer:
         else:
             return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
 
-    def make_unsubscribe_link(self, user_id, app_id, email_address):
+    def make_unsubscribe_link(
+        self, user_id: str, app_id: str, email_address: str
+    ) -> str:
         params = {
             "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
             "app_id": app_id,
@@ -620,7 +718,16 @@ class Mailer:
         )
 
 
-def safe_markup(raw_html):
+def safe_markup(raw_html: str) -> jinja2.Markup:
+    """
+    Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
+
+    Args
+        raw_html: Unsafe HTML.
+
+    Returns:
+        A Markup object ready to safely use in a Jinja template.
+    """
     return jinja2.Markup(
         bleach.linkify(
             bleach.clean(
@@ -635,10 +742,15 @@ def safe_markup(raw_html):
     )
 
 
-def safe_text(raw_text):
+def safe_text(raw_text: str) -> jinja2.Markup:
     """
-    Process text: treat it as HTML but escape any tags (ie. just escape the
-    HTML) then linkify it.
+    Sanitise text (escape any HTML tags), and then linkify any bare URLs.
+
+    Args
+        raw_text: Unsafe text which might include HTML markup.
+
+    Returns:
+        A Markup object ready to safely use in a Jinja template.
     """
     return jinja2.Markup(
         bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
@@ -655,7 +767,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
     return ret
 
 
-def string_ordinal_total(s):
+def string_ordinal_total(s: str) -> int:
     tot = 0
     for c in s:
         tot += ord(c)
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index d8f4a453cd..04c2c1482c 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -15,8 +15,14 @@
 
 import logging
 import re
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
+from synapse.events import EventBase
+from synapse.types import StateMap
+
+if TYPE_CHECKING:
+    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
@@ -28,32 +34,36 @@ ALL_ALONE = "Empty Room"
 
 
 async def calculate_room_name(
-    store,
-    room_state_ids,
-    user_id,
-    fallback_to_members=True,
-    fallback_to_single_member=True,
-):
+    store: "DataStore",
+    room_state_ids: StateMap[str],
+    user_id: str,
+    fallback_to_members: bool = True,
+    fallback_to_single_member: bool = True,
+) -> Optional[str]:
     """
     Works out a user-facing name for the given room as per Matrix
     spec recommendations.
     Does not yet support internationalisation.
     Args:
-        room_state: Dictionary of the room's state
+        store: The data store to query.
+        room_state_ids: Dictionary of the room's state IDs.
         user_id: The ID of the user to whom the room name is being presented
         fallback_to_members: If False, return None instead of generating a name
                              based on the room's members if the room has no
                              title or aliases.
+        fallback_to_single_member: If False, return None instead of generating a
+            name based on the user who invited this user to the room if the room
+            has no title or aliases.
 
     Returns:
-        (string or None) A human readable name for the room.
+        A human readable name for the room, if possible.
     """
     # does it have a name?
     if (EventTypes.Name, "") in room_state_ids:
         m_room_name = await store.get_event(
             room_state_ids[(EventTypes.Name, "")], allow_none=True
         )
-        if m_room_name and m_room_name.content and m_room_name.content["name"]:
+        if m_room_name and m_room_name.content and m_room_name.content.get("name"):
             return m_room_name.content["name"]
 
     # does it have a canonical alias?
@@ -64,15 +74,11 @@ async def calculate_room_name(
         if (
             canon_alias
             and canon_alias.content
-            and canon_alias.content["alias"]
+            and canon_alias.content.get("alias")
             and _looks_like_an_alias(canon_alias.content["alias"])
         ):
             return canon_alias.content["alias"]
 
-    # at this point we're going to need to search the state by all state keys
-    # for an event type, so rearrange the data structure
-    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
-
     if not fallback_to_members:
         return None
 
@@ -84,7 +90,7 @@ async def calculate_room_name(
 
     if (
         my_member_event is not None
-        and my_member_event.content["membership"] == "invite"
+        and my_member_event.content.get("membership") == Membership.INVITE
     ):
         if (EventTypes.Member, my_member_event.sender) in room_state_ids:
             inviter_member_event = await store.get_event(
@@ -97,10 +103,14 @@ async def calculate_room_name(
                         name_from_member_event(inviter_member_event),
                     )
                 else:
-                    return
+                    return None
         else:
             return "Room Invite"
 
+    # at this point we're going to need to search the state by all state keys
+    # for an event type, so rearrange the data structure
+    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
+
     # we're going to have to generate a name based on who's in the room,
     # so find out who is in the room that isn't the user.
     if EventTypes.Member in room_state_bytype_ids:
@@ -110,8 +120,8 @@ async def calculate_room_name(
         all_members = [
             ev
             for ev in member_events.values()
-            if ev.content["membership"] == "join"
-            or ev.content["membership"] == "invite"
+            if ev.content.get("membership") == Membership.JOIN
+            or ev.content.get("membership") == Membership.INVITE
         ]
         # Sort the member events oldest-first so the we name people in the
         # order the joined (it should at least be deterministic rather than
@@ -150,19 +160,19 @@ async def calculate_room_name(
         else:
             return ALL_ALONE
     elif len(other_members) == 1 and not fallback_to_single_member:
-        return
-    else:
-        return descriptor_from_member_events(other_members)
+        return None
 
+    return descriptor_from_member_events(other_members)
 
-def descriptor_from_member_events(member_events):
+
+def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
     """Get a description of the room based on the member events.
 
     Args:
-        member_events (Iterable[FrozenEvent])
+        member_events: The events of a room.
 
     Returns:
-        str
+        The room description
     """
 
     member_events = list(member_events)
@@ -183,22 +193,18 @@ def descriptor_from_member_events(member_events):
         )
 
 
-def name_from_member_event(member_event):
-    if (
-        member_event.content
-        and "displayname" in member_event.content
-        and member_event.content["displayname"]
-    ):
+def name_from_member_event(member_event: EventBase) -> str:
+    if member_event.content and member_event.content.get("displayname"):
         return member_event.content["displayname"]
     return member_event.state_key
 
 
-def _state_as_two_level_dict(state):
-    ret = {}
+def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
+    ret = {}  # type: Dict[str, Dict[str, str]]
     for k, v in state.items():
         ret.setdefault(k[0], {})[k[1]] = v
     return ret
 
 
-def _looks_like_an_alias(string):
+def _looks_like_an_alias(string: str) -> bool:
     return ALIAS_RE.match(string) is not None
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2ce9e444ab..ba1877adcd 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
 INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
 
 
-def _room_member_count(ev, condition, room_member_count):
+def _room_member_count(
+    ev: EventBase, condition: Dict[str, Any], room_member_count: int
+) -> bool:
     return _test_ineq_condition(condition, room_member_count)
 
 
-def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
+def _sender_notification_permission(
+    ev: EventBase,
+    condition: Dict[str, Any],
+    sender_power_level: int,
+    power_levels: Dict[str, Union[int, Dict[str, int]]],
+) -> bool:
     notif_level_key = condition.get("key")
     if notif_level_key is None:
         return False
 
     notif_levels = power_levels.get("notifications", {})
+    assert isinstance(notif_levels, dict)
     room_notif_level = notif_levels.get(notif_level_key, 50)
 
     return sender_power_level >= room_notif_level
 
 
-def _test_ineq_condition(condition, number):
+def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
     if "is" not in condition:
         return False
     m = INEQUALITY_EXPR.match(condition["is"])
@@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
         event: EventBase,
         room_member_count: int,
         sender_power_level: int,
-        power_levels: dict,
+        power_levels: Dict[str, Union[int, Dict[str, int]]],
     ):
         self._event = event
         self._room_member_count = room_member_count
@@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
 
-    def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
+    def matches(
+        self, condition: Dict[str, Any], user_id: str, display_name: str
+    ) -> bool:
         if condition["kind"] == "event_match":
             return self._event_match(condition, user_id)
         elif condition["kind"] == "contains_display_name":
@@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
     return r"(^|\W)%s(\W|$)" % (r,)
 
 
-def _flatten_dict(d, prefix=[], result=None):
+def _flatten_dict(
+    d: Union[EventBase, dict],
+    prefix: Optional[List[str]] = None,
+    result: Optional[Dict[str, str]] = None,
+) -> Dict[str, str]:
+    if prefix is None:
+        prefix = []
     if result is None:
         result = {}
     for key, value in d.items():
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 6e7c880dc0..df34103224 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -12,6 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict
+
+from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
 from synapse.storage import Storage
 from synapse.storage.databases.main import DataStore
@@ -46,7 +49,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     return badge
 
 
-async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
+async def get_context_for_event(
+    storage: Storage, ev: EventBase, user_id: str
+) -> Dict[str, str]:
     ctx = {}
 
     room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 2a52e226e3..2aa7918fb4 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -14,25 +14,31 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Callable, Dict, Optional
 
+from synapse.push import Pusher, PusherConfig
 from synapse.push.emailpusher import EmailPusher
+from synapse.push.httppusher import HttpPusher
 from synapse.push.mailer import Mailer
 
-from .httppusher import HttpPusher
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
 class PusherFactory:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.config = hs.config
 
-        self.pusher_types = {"http": HttpPusher}
+        self.pusher_types = {
+            "http": HttpPusher
+        }  # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
 
         logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
         if hs.config.email_enable_notifs:
-            self.mailers = {}  # app_name -> Mailer
+            self.mailers = {}  # type: Dict[str, Mailer]
 
             self._notif_template_html = hs.config.email_notif_template_html
             self._notif_template_text = hs.config.email_notif_template_text
@@ -41,16 +47,18 @@ class PusherFactory:
 
             logger.info("defined email pusher type")
 
-    def create_pusher(self, pusherdict):
-        kind = pusherdict["kind"]
+    def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
+        kind = pusher_config.kind
         f = self.pusher_types.get(kind, None)
         if not f:
             return None
-        logger.debug("creating %s pusher for %r", kind, pusherdict)
-        return f(self.hs, pusherdict)
+        logger.debug("creating %s pusher for %r", kind, pusher_config)
+        return f(self.hs, pusher_config)
 
-    def _create_email_pusher(self, _hs, pusherdict):
-        app_name = self._app_name_from_pusherdict(pusherdict)
+    def _create_email_pusher(
+        self, _hs: "HomeServer", pusher_config: PusherConfig
+    ) -> EmailPusher:
+        app_name = self._app_name_from_pusherdict(pusher_config)
         mailer = self.mailers.get(app_name)
         if not mailer:
             mailer = Mailer(
@@ -60,10 +68,10 @@ class PusherFactory:
                 template_text=self._notif_template_text,
             )
             self.mailers[app_name] = mailer
-        return EmailPusher(self.hs, pusherdict, mailer)
+        return EmailPusher(self.hs, pusher_config, mailer)
 
-    def _app_name_from_pusherdict(self, pusherdict):
-        data = pusherdict["data"]
+    def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
+        data = pusher_config.data
 
         if isinstance(data, dict):
             brand = data.get("brand")
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f325964983..eed16dbfb5 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Union
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
 
 from prometheus_client import Gauge
 
@@ -23,11 +23,9 @@ from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
-from synapse.push import PusherConfigException
-from synapse.push.emailpusher import EmailPusher
-from synapse.push.httppusher import HttpPusher
+from synapse.push import Pusher, PusherConfig, PusherConfigException
 from synapse.push.pusher import PusherFactory
-from synapse.types import RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
 
 if TYPE_CHECKING:
@@ -77,9 +75,9 @@ class PusherPool:
         self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
 
         # map from user id to app_id:pushkey to pusher
-        self.pushers = {}  # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
+        self.pushers = {}  # type: Dict[str, Dict[str, Pusher]]
 
-    def start(self):
+    def start(self) -> None:
         """Starts the pushers off in a background process.
         """
         if not self._should_start_pushers:
@@ -89,52 +87,53 @@ class PusherPool:
 
     async def add_pusher(
         self,
-        user_id,
-        access_token,
-        kind,
-        app_id,
-        app_display_name,
-        device_display_name,
-        pushkey,
-        lang,
-        data,
-        profile_tag="",
-    ):
+        user_id: str,
+        access_token: Optional[int],
+        kind: str,
+        app_id: str,
+        app_display_name: str,
+        device_display_name: str,
+        pushkey: str,
+        lang: Optional[str],
+        data: JsonDict,
+        profile_tag: str = "",
+    ) -> Optional[Pusher]:
         """Creates a new pusher and adds it to the pool
 
         Returns:
-            EmailPusher|HttpPusher
+            The newly created pusher.
         """
 
         time_now_msec = self.clock.time_msec()
 
+        # create the pusher setting last_stream_ordering to the current maximum
+        # stream ordering, so it will process pushes from this point onwards.
+        last_stream_ordering = self.store.get_room_max_stream_ordering()
+
         # we try to create the pusher just to validate the config: it
         # will then get pulled out of the database,
         # recreated, added and started: this means we have only one
         # code path adding pushers.
         self.pusher_factory.create_pusher(
-            {
-                "id": None,
-                "user_name": user_id,
-                "kind": kind,
-                "app_id": app_id,
-                "app_display_name": app_display_name,
-                "device_display_name": device_display_name,
-                "pushkey": pushkey,
-                "ts": time_now_msec,
-                "lang": lang,
-                "data": data,
-                "last_stream_ordering": None,
-                "last_success": None,
-                "failing_since": None,
-            }
+            PusherConfig(
+                id=None,
+                user_name=user_id,
+                access_token=access_token,
+                profile_tag=profile_tag,
+                kind=kind,
+                app_id=app_id,
+                app_display_name=app_display_name,
+                device_display_name=device_display_name,
+                pushkey=pushkey,
+                ts=time_now_msec,
+                lang=lang,
+                data=data,
+                last_stream_ordering=last_stream_ordering,
+                last_success=None,
+                failing_since=None,
+            )
         )
 
-        # create the pusher setting last_stream_ordering to the current maximum
-        # stream ordering in event_push_actions, so it will process
-        # pushes from this point onwards.
-        last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
-
         await self.store.add_pusher(
             user_id=user_id,
             access_token=access_token,
@@ -154,43 +153,44 @@ class PusherPool:
         return pusher
 
     async def remove_pushers_by_app_id_and_pushkey_not_user(
-        self, app_id, pushkey, not_user_id
-    ):
+        self, app_id: str, pushkey: str, not_user_id: str
+    ) -> None:
         to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
         for p in to_remove:
-            if p["user_name"] != not_user_id:
+            if p.user_name != not_user_id:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     app_id,
                     pushkey,
-                    p["user_name"],
+                    p.user_name,
                 )
-                await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+                await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
 
-    async def remove_pushers_by_access_token(self, user_id, access_tokens):
+    async def remove_pushers_by_access_token(
+        self, user_id: str, access_tokens: Iterable[int]
+    ) -> None:
         """Remove the pushers for a given user corresponding to a set of
         access_tokens.
 
         Args:
-            user_id (str): user to remove pushers for
-            access_tokens (Iterable[int]): access token *ids* to remove pushers
-                for
+            user_id: user to remove pushers for
+            access_tokens: access token *ids* to remove pushers for
         """
         if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
             return
 
         tokens = set(access_tokens)
         for p in await self.store.get_pushers_by_user_id(user_id):
-            if p["access_token"] in tokens:
+            if p.access_token in tokens:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
-                    p["app_id"],
-                    p["pushkey"],
-                    p["user_name"],
+                    p.app_id,
+                    p.pushkey,
+                    p.user_name,
                 )
-                await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+                await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
 
-    def on_new_notifications(self, max_token: RoomStreamToken):
+    def on_new_notifications(self, max_token: RoomStreamToken) -> None:
         if not self.pushers:
             # nothing to do here.
             return
@@ -209,7 +209,7 @@ class PusherPool:
         self._on_new_notifications(max_token)
 
     @wrap_as_background_process("on_new_notifications")
-    async def _on_new_notifications(self, max_token: RoomStreamToken):
+    async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
         # We just use the minimum stream ordering and ignore the vector clock
         # component. This is safe to do as long as we *always* ignore the vector
         # clock components.
@@ -239,7 +239,9 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_notifications")
 
-    async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+    async def on_new_receipts(
+        self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
+    ) -> None:
         if not self.pushers:
             # nothing to do here.
             return
@@ -267,28 +269,30 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_receipts")
 
-    async def start_pusher_by_id(self, app_id, pushkey, user_id):
+    async def start_pusher_by_id(
+        self, app_id: str, pushkey: str, user_id: str
+    ) -> Optional[Pusher]:
         """Look up the details for the given pusher, and start it
 
         Returns:
-            EmailPusher|HttpPusher|None: The pusher started, if any
+            The pusher started, if any
         """
         if not self._should_start_pushers:
-            return
+            return None
 
         if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
-            return
+            return None
 
         resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
 
-        pusher_dict = None
+        pusher_config = None
         for r in resultlist:
-            if r["user_name"] == user_id:
-                pusher_dict = r
+            if r.user_name == user_id:
+                pusher_config = r
 
         pusher = None
-        if pusher_dict:
-            pusher = await self._start_pusher(pusher_dict)
+        if pusher_config:
+            pusher = await self._start_pusher(pusher_config)
 
         return pusher
 
@@ -303,44 +307,44 @@ class PusherPool:
 
         logger.info("Started pushers")
 
-    async def _start_pusher(self, pusherdict):
+    async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
         """Start the given pusher
 
         Args:
-            pusherdict (dict): dict with the values pulled from the db table
+            pusher_config: The pusher configuration with the values pulled from the db table
 
         Returns:
-            EmailPusher|HttpPusher
+            The newly created pusher or None.
         """
         if not self._pusher_shard_config.should_handle(
-            self._instance_name, pusherdict["user_name"]
+            self._instance_name, pusher_config.user_name
         ):
-            return
+            return None
 
         try:
-            p = self.pusher_factory.create_pusher(pusherdict)
+            p = self.pusher_factory.create_pusher(pusher_config)
         except PusherConfigException as e:
             logger.warning(
                 "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
-                pusherdict["id"],
-                pusherdict.get("user_name"),
-                pusherdict.get("app_id"),
-                pusherdict.get("pushkey"),
+                pusher_config.id,
+                pusher_config.user_name,
+                pusher_config.app_id,
+                pusher_config.pushkey,
                 e,
             )
-            return
+            return None
         except Exception:
             logger.exception(
-                "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+                "Couldn't start pusher id %i: caught Exception", pusher_config.id,
             )
-            return
+            return None
 
         if not p:
-            return
+            return None
 
-        appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
+        appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
 
-        byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+        byuser = self.pushers.setdefault(pusher_config.user_name, {})
         if appid_pushkey in byuser:
             byuser[appid_pushkey].on_stop()
         byuser[appid_pushkey] = p
@@ -350,8 +354,8 @@ class PusherPool:
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
         # push.
-        user_id = pusherdict["user_name"]
-        last_stream_ordering = pusherdict["last_stream_ordering"]
+        user_id = pusher_config.user_name
+        last_stream_ordering = pusher_config.last_stream_ordering
         if last_stream_ordering:
             have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
                 user_id, last_stream_ordering
@@ -365,7 +369,7 @@ class PusherPool:
 
         return p
 
-    async def remove_pusher(self, app_id, pushkey, user_id):
+    async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
         appid_pushkey = "%s:%s" % (app_id, pushkey)
 
         byuser = self.pushers.get(user_id, {})
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index c97e0df1f5..bfd46a3730 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -86,8 +86,8 @@ REQUIREMENTS = [
 
 CONDITIONAL_REQUIREMENTS = {
     "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
-    # we use execute_batch, which arrived in psycopg 2.7.
-    "postgres": ["psycopg2>=2.7"],
+    # we use execute_values with the fetch param, which arrived in psycopg 2.8.
+    "postgres": ["psycopg2>=2.8"],
     # ACME support is required to provision TLS certificates from authorities
     # that use the protocol, such as Let's Encrypt.
     "acme": [
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index a84a064c8d..dd527e807f 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -15,6 +15,7 @@
 
 from synapse.http.server import JsonResource
 from synapse.replication.http import (
+    account_data,
     devices,
     federation,
     login,
@@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
         presence.register_servlets(hs, self)
         membership.register_servlets(hs, self)
         streams.register_servlets(hs, self)
+        account_data.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2b3972cb14..288727a566 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         assert self.METHOD in ("PUT", "POST", "GET")
 
+        self._replication_secret = None
+        if hs.config.worker.worker_replication_secret:
+            self._replication_secret = hs.config.worker.worker_replication_secret
+
+    def _check_auth(self, request) -> None:
+        # Get the authorization header.
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if len(auth_headers) > 1:
+            raise RuntimeError("Too many Authorization headers.")
+        parts = auth_headers[0].split(b" ")
+        if parts[0] == b"Bearer" and len(parts) == 2:
+            received_secret = parts[1].decode("ascii")
+            if self._replication_secret == received_secret:
+                # Success!
+                return
+
+        raise RuntimeError("Invalid Authorization header.")
+
     @abc.abstractmethod
     async def _serialize_payload(**kwargs):
         """Static method that is called when creating a request.
@@ -150,9 +169,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
 
+        replication_secret = None
+        if hs.config.worker.worker_replication_secret:
+            replication_secret = hs.config.worker.worker_replication_secret.encode(
+                "ascii"
+            )
+
         @trace(opname="outgoing_replication_request")
         @outgoing_gauge.track_inprogress()
-        async def send_request(instance_name="master", **kwargs):
+        async def send_request(*, instance_name="master", **kwargs):
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
@@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 # the master, and so whether we should clean up or not.
                 while True:
                     headers = {}  # type: Dict[bytes, List[bytes]]
+                    # Add an authorization header, if configured.
+                    if replication_secret:
+                        headers[b"Authorization"] = [b"Bearer " + replication_secret]
                     inject_active_span_byte_dict(headers, None, check_destination=False)
                     try:
                         result = await request_func(uri, data, headers=headers)
@@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         """
 
         url_args = list(self.PATH_ARGS)
-        handler = self._handle_request
         method = self.METHOD
 
         if self.CACHE:
-            handler = self._cached_handler  # type: ignore
             url_args.append("txn_id")
 
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
         pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
 
         http_server.register_paths(
-            method, [pattern], handler, self.__class__.__name__,
+            method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
         )
 
-    def _cached_handler(self, request, txn_id, **kwargs):
+    def _check_auth_and_handle(self, request, **kwargs):
         """Called on new incoming requests when caching is enabled. Checks
         if there is a cached response for the request and returns that,
         otherwise calls `_handle_request` and caches its response.
@@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         # We just use the txn_id here, but we probably also want to use the
         # other PATH_ARGS as well.
 
-        assert self.CACHE
+        # Check the authorization headers before handling the request.
+        if self._replication_secret:
+            self._check_auth(request)
+
+        if self.CACHE:
+            txn_id = kwargs.pop("txn_id")
+
+            return self.response_cache.wrap(
+                txn_id, self._handle_request, request, **kwargs
+            )
 
-        return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
+        return self._handle_request(request, **kwargs)
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
new file mode 100644
index 0000000000..52d32528ee
--- /dev/null
+++ b/synapse/replication/http/account_data.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+    """Add user account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_user_account_data"
+    PATH_ARGS = ("user_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_for_user(
+            user_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+    """Add room account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_room_account_data"
+    PATH_ARGS = ("user_id", "room_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_to_room(
+            user_id, room_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+    """Add tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_tag"
+    PATH_ARGS = ("user_id", "room_id", "tag")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_tag_to_room(
+            user_id, room_id, tag, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+    """Remove tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+        {}
+
+    """
+
+    NAME = "remove_tag"
+    PATH_ARGS = (
+        "user_id",
+        "room_id",
+        "tag",
+    )
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag):
+
+        return {}
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+    ReplicationUserAccountDataRestServlet(hs).register(http_server)
+    ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+    ReplicationAddTagRestServlet(hs).register(http_server)
+    ReplicationRemoveTagRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 4c81e2d784..36071feb36 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
-    async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+    async def _serialize_payload(
+        user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
+    ):
         """
         Args:
             device_id (str|None): Device ID to use, if None a new one is
@@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             "device_id": device_id,
             "initial_display_name": initial_display_name,
             "is_guest": is_guest,
+            "is_appservice_ghost": is_appservice_ghost,
         }
 
     async def _handle_request(self, request, user_id):
@@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         device_id = content["device_id"]
         initial_display_name = content["initial_display_name"]
         is_guest = content["is_guest"]
+        is_appservice_ghost = content["is_appservice_ghost"]
 
         device_id, access_token = await self.registration_handler.register_device(
-            user_id, device_id, initial_display_name, is_guest
+            user_id,
+            device_id,
+            initial_display_name,
+            is_guest,
+            is_appservice_ghost=is_appservice_ghost,
         )
 
         return 200, {"device_id": device_id, "access_token": access_token}
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d0089fe06c..693c9ab901 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index eb74903d68..0d39a93ed2 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -12,21 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Optional, Tuple
 
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import _load_current_id
 
 
 class SlavedIdTracker:
-    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+    def __init__(
+        self,
+        db_conn: Connection,
+        table: str,
+        column: str,
+        extra_tables: Optional[List[Tuple[str, str]]] = None,
+        step: int = 1,
+    ):
         self.step = step
         self._current = _load_current_id(db_conn, table, column, step)
-        for table, column in extra_tables:
-            self.advance(None, _load_current_id(db_conn, table, column))
+        if extra_tables:
+            for table, column in extra_tables:
+                self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, instance_name, new_id):
+    def advance(self, instance_name: Optional[str], new_id: int):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         """
 
         Returns:
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 4268565fc8..21afe5f155 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,47 +15,9 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = SlavedIdTracker(
-            db_conn,
-            "account_data",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self):
-        return self._account_data_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                if not row.room_id:
-                    self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id)
-                    )
-                self.get_account_data_for_user.invalidate((row.user_id,))
-                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
-                self.get_account_data_for_room_and_type.invalidate(
-                    (row.user_id, row.room_id, row.data_type)
-                )
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 5b045bed02..1260f6d141 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -14,46 +14,8 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-        self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_inbox", "stream_id"
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(instance_name, token)
-            for row in rows:
-                if row.entity.startswith("@"):
-                    self._device_inbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-                else:
-                    self._device_federation_outbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730ba8..045bd014da 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import TYPE_CHECKING
 
 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
-        self._pushers_id_gen = SlavedIdTracker(
+        self._pushers_id_gen = SlavedIdTracker(  # type: ignore
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
 
-    def get_pushers_stream_token(self):
+    def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token, rows
+    ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)
+            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6195917376..3dfdd9961d 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,43 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = SlavedIdTracker(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
-        )
-        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
-        self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ReceiptsStream.NAME:
-            self._receipts_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.invalidate_caches_for_receipt(
-                    row.room_id, row.receipt_type, row.user_id
-                )
-                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..34fa3ff5b3
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+set_counter = Counter(
+    "synapse_external_cache_set",
+    "Number of times we set a cache",
+    labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+    "synapse_external_cache_get",
+    "Number of times we get a cache",
+    labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+    """A cache backed by an external Redis. Does nothing if no Redis is
+    configured.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._redis_connection = hs.get_outbound_redis_connection()
+
+    def _get_redis_key(self, cache_name: str, key: str) -> str:
+        return "cache_v1:%s:%s" % (cache_name, key)
+
+    def is_enabled(self) -> bool:
+        """Whether the external cache is used or not.
+
+        It's safe to use the cache when this returns false, the methods will
+        just no-op, but the function is useful to avoid doing unnecessary work.
+        """
+        return self._redis_connection is not None
+
+    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+        """Add the key/value to the named cache, with the expiry time given.
+        """
+
+        if self._redis_connection is None:
+            return
+
+        set_counter.labels(cache_name).inc()
+
+        # txredisapi requires the value to be string, bytes or numbers, so we
+        # encode stuff in JSON.
+        encoded_value = json_encoder.encode(value)
+
+        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+        return await make_deferred_yieldable(
+            self._redis_connection.set(
+                self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+            )
+        )
+
+    async def get(self, cache_name: str, key: str) -> Optional[Any]:
+        """Look up a key/value in the named cache.
+        """
+
+        if self._redis_connection is None:
+            return None
+
+        result = await make_deferred_yieldable(
+            self._redis_connection.get(self._get_redis_key(cache_name, key))
+        )
+
+        logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+        get_counter.labels(cache_name, result is not None).inc()
+
+        if not result:
+            return None
+
+        # For some reason the integers get magically converted back to integers
+        if isinstance(result, int):
+            return result
+
+        return json_decoder.decode(result)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 95e5502bf2..8ea8dcd587 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -51,14 +52,21 @@ from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
+    AccountDataStream,
     BackfillStream,
     CachesStream,
     EventsStream,
     FederationStream,
+    ReceiptsStream,
     Stream,
+    TagAccountDataStream,
+    ToDeviceStream,
     TypingStream,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -84,7 +92,7 @@ class ReplicationCommandHandler:
     back out to connections.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
         self._store = hs.get_datastore()
@@ -115,6 +123,14 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, ToDeviceStream):
+                # Only add ToDeviceStream as a source on instances in charge of
+                # sending to device messages.
+                if hs.get_instance_name() in hs.config.worker.writers.to_device:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             if isinstance(stream, TypingStream):
                 # Only add TypingStream as a source on the instance in charge of
                 # typing.
@@ -123,6 +139,22 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+                # Only add AccountDataStream and TagAccountDataStream as a source on the
+                # instance in charge of account_data persistence.
+                if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
+            if isinstance(stream, ReceiptsStream):
+                # Only add ReceiptsStream as a source on the instance in charge of
+                # receipts.
+                if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
@@ -254,13 +286,6 @@ class ReplicationCommandHandler:
         if hs.config.redis.redis_enabled:
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
-                lazyConnection,
-            )
-
-            logger.info(
-                "Connecting to redis (host=%r port=%r)",
-                hs.config.redis_host,
-                hs.config.redis_port,
             )
 
             # First let's ensure that we have a ReplicationStreamer started.
@@ -271,13 +296,7 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = lazyConnection(
-                reactor=hs.get_reactor(),
-                host=hs.config.redis_host,
-                port=hs.config.redis_port,
-                password=hs.config.redis.redis_password,
-                reconnect=True,
-            )
+            outbound_redis_connection = hs.get_outbound_redis_connection()
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..804da994ea 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
         ctx_name = "replication-conn-%s" % self.conn_id
-        self._logging_context = BackgroundProcessLoggingContext(ctx_name)
-        self._logging_context.request = ctx_name
+        self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Type, cast
 
 import txredisapi
 
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import (
     BackgroundProcessLoggingContext,
     run_as_background_process,
+    wrap_as_background_process,
 )
 from synapse.replication.tcp.commands import (
     Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     immediately after initialisation.
 
     Attributes:
-        handler: The command handler to handle incoming commands.
-        stream_name: The *redis* stream name to subscribe to and publish from
-            (not anything to do with Synapse replication streams).
-        outbound_redis_connection: The connection to redis to use to send
+        synapse_handler: The command handler to handle incoming commands.
+        synapse_stream_name: The *redis* stream name to subscribe to and publish
+            from (not anything to do with Synapse replication streams).
+        synapse_outbound_redis_connection: The connection to redis to use to send
             commands.
     """
 
-    handler = None  # type: ReplicationCommandHandler
-    stream_name = None  # type: str
-    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+    synapse_handler = None  # type: ReplicationCommandHandler
+    synapse_stream_name = None  # type: str
+    synapse_outbound_redis_connection = None  # type: txredisapi.RedisProtocol
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
-        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
-        await make_deferred_yieldable(self.subscribe(self.stream_name))
+        logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
-        self.handler.new_connection(self)
+        self.synapse_handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")
 
         # We send out our positions when there is a new connection in case the
         # other side missed updates. We do this for Redis connections as the
         # otherside won't know we've connected and so won't issue a REPLICATE.
-        self.handler.send_positions_to_connection(self)
+        self.synapse_handler.send_positions_to_connection(self)
 
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             cmd: received command
         """
 
-        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
         if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
             return
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
-        self.handler.lost_connection(self)
+        self.synapse_handler.lost_connection(self)
 
         # mark the logging context as finished
         self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
         await make_deferred_yieldable(
-            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+            self.synapse_outbound_redis_connection.publish(
+                self.synapse_stream_name, encoded_string
+            )
+        )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+    """A subclass of RedisFactory that periodically sends pings to ensure that
+    we detect dead connections.
+    """
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        uuid: str,
+        dbid: Optional[int],
+        poolsize: int,
+        isLazy: bool = False,
+        handler: Type = txredisapi.ConnectionHandler,
+        charset: str = "utf-8",
+        password: Optional[str] = None,
+        replyTimeout: int = 30,
+        convertNumbers: Optional[int] = True,
+    ):
+        super().__init__(
+            uuid=uuid,
+            dbid=dbid,
+            poolsize=poolsize,
+            isLazy=isLazy,
+            handler=handler,
+            charset=charset,
+            password=password,
+            replyTimeout=replyTimeout,
+            convertNumbers=convertNumbers,
         )
 
+        hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+    @wrap_as_background_process("redis_ping")
+    async def _send_ping(self):
+        for connection in self.pool:
+            try:
+                await make_deferred_yieldable(connection.ping())
+            except Exception:
+                logger.warning("Failed to send ping to a redis connection")
 
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     """This is a reconnecting factory that connects to redis and immediately
     subscribes to a stream.
 
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
     ):
 
-        super().__init__()
-
-        # This sets the password on the RedisFactory base class (as
-        # SubscriberFactory constructor doesn't pass it through).
-        self.password = hs.config.redis.redis_password
+        super().__init__(
+            hs,
+            uuid="subscriber",
+            dbid=None,
+            poolsize=1,
+            replyTimeout=30,
+            password=hs.config.redis.redis_password,
+        )
 
-        self.handler = hs.get_tcp_replication()
-        self.stream_name = hs.hostname
+        self.synapse_handler = hs.get_tcp_replication()
+        self.synapse_stream_name = hs.hostname
 
-        self.outbound_redis_connection = outbound_redis_connection
+        self.synapse_outbound_redis_connection = outbound_redis_connection
 
     def buildProtocol(self, addr):
-        p = super().buildProtocol(addr)  # type: RedisSubscriber
+        p = super().buildProtocol(addr)
+        p = cast(RedisSubscriber, p)
 
         # We do this here rather than add to the constructor of `RedisSubcriber`
         # as to do so would involve overriding `buildProtocol` entirely, however
         # the base method does some other things than just instantiating the
         # protocol.
-        p.handler = self.handler
-        p.outbound_redis_connection = self.outbound_redis_connection
-        p.stream_name = self.stream_name
-        p.password = self.password
+        p.synapse_handler = self.synapse_handler
+        p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+        p.synapse_stream_name = self.synapse_stream_name
 
         return p
 
 
 def lazyConnection(
-    reactor,
+    hs: "HomeServer",
     host: str = "localhost",
     port: int = 6379,
     dbid: Optional[int] = None,
     reconnect: bool = True,
-    charset: str = "utf-8",
     password: Optional[str] = None,
-    connectTimeout: Optional[int] = None,
-    replyTimeout: Optional[int] = None,
-    convertNumbers: bool = True,
+    replyTimeout: int = 30,
 ) -> txredisapi.RedisProtocol:
-    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
-    reactor.
+    """Creates a connection to Redis that is lazily set up and reconnects if the
+    connections is lost.
     """
 
-    isLazy = True
-    poolsize = 1
-
     uuid = "%s:%d" % (host, port)
-    factory = txredisapi.RedisFactory(
-        uuid,
-        dbid,
-        poolsize,
-        isLazy,
-        txredisapi.ConnectionHandler,
-        charset,
-        password,
-        replyTimeout,
-        convertNumbers,
+    factory = SynapseRedisFactory(
+        hs,
+        uuid=uuid,
+        dbid=dbid,
+        poolsize=1,
+        isLazy=True,
+        handler=txredisapi.ConnectionHandler,
+        password=password,
+        replyTimeout=replyTimeout,
     )
     factory.continueTrying = reconnect
-    for x in range(poolsize):
-        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    reactor = hs.get_reactor()
+    reactor.connectTCP(host, port, factory, 30)
 
     return factory.handler
diff --git a/synapse/res/templates/notif.html b/synapse/res/templates/notif.html
index 6d76064d13..0aaef97df8 100644
--- a/synapse/res/templates/notif.html
+++ b/synapse/res/templates/notif.html
@@ -29,7 +29,7 @@
                         {{ message.body_text_html }}
                     {%- elif message.msgtype == "m.notice" %}
                         {{ message.body_text_html }}
-                    {%- elif message.msgtype == "m.image" %}
+                    {%- elif message.msgtype == "m.image" and message.image_url %}
                         <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
                     {%- elif message.msgtype == "m.file" %}
                         <span class="filename">{{ message.body_text_plain }}</span>
diff --git a/synapse/res/templates/sso.css b/synapse/res/templates/sso.css
new file mode 100644
index 0000000000..46b309ea4e
--- /dev/null
+++ b/synapse/res/templates/sso.css
@@ -0,0 +1,88 @@
+body {
+  font-family: "Inter", "Helvetica", "Arial", sans-serif;
+  font-size: 14px;
+  color: #17191C;
+}
+
+header {
+  max-width: 480px;
+  width: 100%;
+  margin: 24px auto;
+  text-align: center;
+}
+
+header p {
+  color: #737D8C;
+  line-height: 24px;
+}
+
+h1 {
+  font-size: 24px;
+}
+
+.error_page h1 {
+  color: #FE2928;
+}
+
+h2 {
+  font-size: 14px;
+}
+
+h2 img {
+  vertical-align: middle;
+  margin-right: 8px;
+  width: 24px;
+  height: 24px;
+}
+
+label {
+  cursor: pointer;
+}
+
+main {
+  max-width: 360px;
+  width: 100%;
+  margin: 24px auto;
+}
+
+.primary-button {
+  border: none;
+  text-decoration: none;
+  padding: 12px;
+  color: white;
+  background-color: #418DED;
+  font-weight: bold;
+  display: block;
+  border-radius: 12px;
+  width: 100%;
+  box-sizing: border-box;
+  margin: 16px 0;
+  cursor: pointer;
+  text-align: center;
+}
+
+.profile {
+  display: flex;
+  justify-content: center;
+  margin: 24px 0;
+}
+
+.profile .avatar {
+  width: 36px;
+  height: 36px;
+  border-radius: 100%;
+  display: block;
+  margin-right: 8px;
+}
+
+.profile .display-name {
+  font-weight: bold;
+  margin-bottom: 4px;
+}
+.profile .user-id {
+  color: #737D8C;
+}
+
+.profile .display-name, .profile .user-id {
+  line-height: 18px;
+}
diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html
index 4eb8db9fb4..50a0979c2f 100644
--- a/synapse/res/templates/sso_account_deactivated.html
+++ b/synapse/res/templates/sso_account_deactivated.html
@@ -1,10 +1,24 @@
 <!DOCTYPE html>
 <html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <title>SSO account deactivated</title>
-</head>
-    <body>
-        <p>This account has been deactivated.</p>
+    <head>
+        <meta charset="UTF-8">
+        <title>SSO account deactivated</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+        </style>
+    </head>
+    <body class="error_page">
+        <header>
+            <h1>Your account has been deactivated</h1>
+            <p>
+                <strong>No account found</strong>
+            </p>
+            <p>
+                Your account might have been deactivated by the server administrator.
+                You can either try to create a new account or contact the server’s
+                administrator.
+            </p>
+        </header>
     </body>
 </html>
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
new file mode 100644
index 0000000000..36850a2d6a
--- /dev/null
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -0,0 +1,161 @@
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <title>Synapse Login</title>
+    <meta charset="utf-8">
+    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <style type="text/css">
+      {% include "sso.css" without context %}
+
+      .username_input {
+        display: flex;
+        border: 2px solid #418DED;
+        border-radius: 8px;
+        padding: 12px;
+        position: relative;
+        margin: 16px 0;
+        align-items: center;
+        font-size: 12px;
+      }
+
+      .username_input.invalid {
+        border-color: #FE2928;
+      }
+
+      .username_input.invalid input, .username_input.invalid label {
+        color: #FE2928;
+      }
+
+      .username_input div, .username_input input {
+        line-height: 18px;
+        font-size: 14px;
+      }
+
+      .username_input label {
+        position: absolute;
+        top: -8px;
+        left: 14px;
+        font-size: 80%;
+        background: white;
+        padding: 2px;
+      }
+
+      .username_input input {
+        flex: 1;
+        display: block;
+        min-width: 0;
+        border: none;
+      }
+
+      .username_input div {
+        color: #8D99A5;
+      }
+
+      .idp-pick-details {
+        border: 1px solid #E9ECF1;
+        border-radius: 8px;
+        margin: 24px 0;
+      }
+
+      .idp-pick-details h2 {
+        margin: 0;
+        padding: 8px 12px;
+      }
+
+      .idp-pick-details .idp-detail {
+        border-top: 1px solid #E9ECF1;
+        padding: 12px;
+      }
+      .idp-pick-details .check-row {
+        display: flex;
+        align-items: center;
+      }
+
+      .idp-pick-details .check-row .name {
+        flex: 1;
+      }
+
+      .idp-pick-details .use, .idp-pick-details .idp-value {
+        color: #737D8C;
+      }
+
+      .idp-pick-details .idp-value {
+        margin: 0;
+        margin-top: 8px;
+      }
+
+      .idp-pick-details .avatar {
+        width: 53px;
+        height: 53px;
+        border-radius: 100%;
+        display: block;
+        margin-top: 8px;
+      }
+
+      output {
+        padding: 0 14px;
+        display: block;
+      }
+
+      output.error {
+        color: #FE2928;
+      }
+    </style>
+  </head>
+  <body>
+    <header>
+      <h1>Your account is nearly ready</h1>
+      <p>Check your details before creating an account on {{ server_name }}</p>
+    </header>
+    <main>
+      <form method="post" class="form__input" id="form">
+        <div class="username_input" id="username_input">
+          <label for="field-username">Username</label>
+          <div class="prefix">@</div>
+          <input type="text" name="username" id="field-username" autofocus>
+          <div class="postfix">:{{ server_name }}</div>
+        </div>
+        <output for="username_input" id="field-username-output"></output>
+        <input type="submit" value="Continue" class="primary-button">
+        {% if user_attributes %}
+        <section class="idp-pick-details">
+          <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2>
+          {% if user_attributes.avatar_url %}
+          <div class="idp-detail idp-avatar">
+            <div class="check-row">
+              <label for="idp-avatar" class="name">Avatar</label>
+              <label for="idp-avatar" class="use">Use</label>
+              <input type="checkbox" name="use_avatar" id="idp-avatar" value="true" checked>
+            </div>
+            <img src="{{ user_attributes.avatar_url }}" class="avatar" />
+          </div>
+          {% endif %}
+          {% if user_attributes.display_name %}
+          <div class="idp-detail">
+            <div class="check-row">
+              <label for="idp-displayname" class="name">Display name</label>
+              <label for="idp-displayname" class="use">Use</label>
+              <input type="checkbox" name="use_display_name" id="idp-displayname" value="true" checked>
+            </div>
+            <p class="idp-value">{{ user_attributes.display_name }}</p>
+          </div>
+          {% endif %}
+          {% for email in user_attributes.emails %}
+          <div class="idp-detail">
+            <div class="check-row">
+              <label for="idp-email{{ loop.index }}" class="name">E-mail</label>
+              <label for="idp-email{{ loop.index }}" class="use">Use</label>
+              <input type="checkbox" name="use_email" id="idp-email{{ loop.index }}" value="{{ email }}" checked>
+            </div>
+            <p class="idp-value">{{ email }}</p>
+          </div>
+          {% endfor %}
+        </section>
+        {% endif %}
+      </form>
+    </main>
+    <script type="text/javascript">
+      {% include "sso_auth_account_details.js" without context %}
+    </script>
+  </body>
+</html>
diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js
new file mode 100644
index 0000000000..3c45df9078
--- /dev/null
+++ b/synapse/res/templates/sso_auth_account_details.js
@@ -0,0 +1,116 @@
+const usernameField = document.getElementById("field-username");
+const usernameOutput = document.getElementById("field-username-output");
+const form = document.getElementById("form");
+
+// needed to validate on change event when no input was changed
+let needsValidation = true;
+let isValid = false;
+
+function throttle(fn, wait) {
+    let timeout;
+    const throttleFn = function() {
+        const args = Array.from(arguments);
+        if (timeout) {
+            clearTimeout(timeout);
+        }
+        timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait);
+    };
+    throttleFn.cancelQueued = function() {
+        clearTimeout(timeout);
+    };
+    return throttleFn;
+}
+
+function checkUsernameAvailable(username) {
+    let check_uri = 'check?username=' + encodeURIComponent(username);
+    return fetch(check_uri, {
+        // include the cookie
+        "credentials": "same-origin",
+    }).then(function(response) {
+        if(!response.ok) {
+            // for non-200 responses, raise the body of the response as an exception
+            return response.text().then((text) => { throw new Error(text); });
+        } else {
+            return response.json();
+        }
+    }).then(function(json) {
+        if(json.error) {
+            return {message: json.error};
+        } else if(json.available) {
+            return {available: true};
+        } else {
+            return {message: username + " is not available, please choose another."};
+        }
+    });
+}
+
+const allowedUsernameCharacters = new RegExp("^[a-z0-9\\.\\_\\-\\/\\=]+$");
+const allowedCharactersString = "lowercase letters, digits, ., _, -, /, =";
+
+function reportError(error) {
+    throttledCheckUsernameAvailable.cancelQueued();
+    usernameOutput.innerText = error;
+    usernameOutput.classList.add("error");
+    usernameField.parentElement.classList.add("invalid");
+    usernameField.focus();
+}
+
+function validateUsername(username) {
+    isValid = false;
+    needsValidation = false;
+    usernameOutput.innerText = "";
+    usernameField.parentElement.classList.remove("invalid");
+    usernameOutput.classList.remove("error");
+    if (!username) {
+        return reportError("Please provide a username");
+    }
+    if (username.length > 255) {
+        return reportError("Too long, please choose something shorter");
+    }
+    if (!allowedUsernameCharacters.test(username)) {
+        return reportError("Invalid username, please only use " + allowedCharactersString);
+    }
+    usernameOutput.innerText = "Checking if username is available …";
+    throttledCheckUsernameAvailable(username);
+}
+
+const throttledCheckUsernameAvailable = throttle(function(username) {
+    const handleError = function(err) {
+        // don't prevent form submission on error
+        usernameOutput.innerText = "";
+        isValid = true;
+    };
+    try {
+        checkUsernameAvailable(username).then(function(result) {
+            if (!result.available) {
+                reportError(result.message);
+            } else {
+                isValid = true;
+                usernameOutput.innerText = "";
+            }
+        }, handleError);
+    } catch (err) {
+        handleError(err);
+    }
+}, 500);
+
+form.addEventListener("submit", function(evt) {
+    if (needsValidation) {
+        validateUsername(usernameField.value);
+        evt.preventDefault();
+        return;
+    }
+    if (!isValid) {
+        evt.preventDefault();
+        usernameField.focus();
+        return;
+    }
+});
+usernameField.addEventListener("input", function(evt) {
+    validateUsername(usernameField.value);
+});
+usernameField.addEventListener("change", function(evt) {
+    if (needsValidation) {
+        validateUsername(usernameField.value);
+    }
+});
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
new file mode 100644
index 0000000000..c9bd4bef20
--- /dev/null
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -0,0 +1,25 @@
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <title>Authentication failed</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+        </style>
+    </head>
+    <body class="error_page">
+        <header>
+            <h1>That doesn't look right</h1>
+            <p>
+                <strong>We were unable to validate your {{ server_name }} account</strong>
+                via single&nbsp;sign&#8209;on&nbsp;(SSO), because the SSO Identity
+                Provider returned different details than when you logged in.
+            </p>
+            <p>
+                Try the operation again, and ensure that you use the same details on
+                the Identity Provider as when you log into your account.
+            </p>
+        </header>
+    </body>
+</html>
diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
index 0d9de9d465..2099c2f1f8 100644
--- a/synapse/res/templates/sso_auth_confirm.html
+++ b/synapse/res/templates/sso_auth_confirm.html
@@ -1,14 +1,28 @@
-<html>
-<head>
-    <title>Authentication</title>
-</head>
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <title>Authentication</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+        </style>
+    </head>
     <body>
-        <div>
+        <header>
+            <h1>Confirm it's you to continue</h1>
             <p>
-                A client is trying to {{ description | e }}. To confirm this action,
-                <a href="{{ redirect_url | e }}">re-authenticate with single sign-on</a>.
-                If you did not expect this, your account may be compromised!
+                A client is trying to {{ description }}. To confirm this action
+                re-authorize your account with single sign-on.
             </p>
-        </div>
+            <p><strong>
+                If you did not expect this, your account may be compromised.
+            </strong></p>
+        </header>
+        <main>
+            <a href="{{ redirect_url }}" class="primary-button">
+                Continue with {{ idp.idp_name }}
+            </a>
+        </main>
     </body>
 </html>
diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html
index 03f1419467..3b975d7219 100644
--- a/synapse/res/templates/sso_auth_success.html
+++ b/synapse/res/templates/sso_auth_success.html
@@ -1,18 +1,27 @@
-<html>
-<head>
-    <title>Authentication Successful</title>
-    <script>
-    if (window.onAuthDone) {
-        window.onAuthDone();
-    } else if (window.opener && window.opener.postMessage) {
-        window.opener.postMessage("authDone", "*");
-    }
-    </script>
-</head>
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <title>Authentication successful</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+        </style>
+        <script>
+            if (window.onAuthDone) {
+                window.onAuthDone();
+            } else if (window.opener && window.opener.postMessage) {
+                window.opener.postMessage("authDone", "*");
+            }
+        </script>
+    </head>
     <body>
-        <div>
-            <p>Thank you</p>
-            <p>You may now close this window and return to the application</p>
-        </div>
+        <header>
+            <h1>Thank you</h1>
+            <p>
+                Now we know it’s you, you can close this window and return to the
+                application.
+            </p>
+        </header>
     </body>
 </html>
diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 944bc9c9ca..b223ca0f56 100644
--- a/synapse/res/templates/sso_error.html
+++ b/synapse/res/templates/sso_error.html
@@ -1,53 +1,68 @@
 <!DOCTYPE html>
 <html lang="en">
-<head>
-    <meta charset="UTF-8">
-    <title>SSO error</title>
-</head>
-<body>
+    <head>
+        <meta charset="UTF-8">
+        <title>Authentication failed</title>
+        <meta name="viewport" content="width=device-width, user-scalable=no">
+        <style type="text/css">
+            {% include "sso.css" without context %}
+
+            #error_code {
+                margin-top: 56px;
+            }
+        </style>
+    </head>
+    <body class="error_page">
 {# If an error of unauthorised is returned it means we have actively rejected their login #}
 {% if error == "unauthorised" %}
-    <p>You are not allowed to log in here.</p>
+        <header>
+            <p>You are not allowed to log in here.</p>
+        </header>
 {% else %}
-    <p>
-        There was an error during authentication:
-    </p>
-    <div id="errormsg" style="margin:20px 80px">{{ error_description | e }}</div>
-    <p>
-        If you are seeing this page after clicking a link sent to you via email, make
-        sure you only click the confirmation link once, and that you open the
-        validation link in the same client you're logging in from.
-    </p>
-    <p>
-        Try logging in again from your Matrix client and if the problem persists
-        please contact the server's administrator.
-    </p>
-    <p>Error: <code>{{ error }}</code></p>
+        <header>
+            <h1>There was an error</h1>
+            <p>
+                <strong id="errormsg">{{ error_description }}</strong>
+            </p>
+            <p>
+                If you are seeing this page after clicking a link sent to you via email,
+                make sure you only click the confirmation link once, and that you open
+                the validation link in the same client you're logging in from.
+            </p>
+            <p>
+                Try logging in again from your Matrix client and if the problem persists
+                please contact the server's administrator.
+            </p>
+            <div id="error_code">
+                <p><strong>Error code</strong></p>
+                <p>{{ error }}</p>
+            </div>
+        </header>
 
-    <script type="text/javascript">
-        // Error handling to support Auth0 errors that we might get through a GET request
-        // to the validation endpoint. If an error is provided, it's either going to be
-        // located in the query string or in a query string-like URI fragment.
-        // We try to locate the error from any of these two locations, but if we can't
-        // we just don't print anything specific.
-        let searchStr = "";
-        if (window.location.search) {
-            // window.location.searchParams isn't always defined when
-            // window.location.search is, so it's more reliable to parse the latter.
-            searchStr = window.location.search;
-        } else if (window.location.hash) {
-            // Replace the # with a ? so that URLSearchParams does the right thing and
-            // doesn't parse the first parameter incorrectly.
-            searchStr = window.location.hash.replace("#", "?");
-        }
+        <script type="text/javascript">
+            // Error handling to support Auth0 errors that we might get through a GET request
+            // to the validation endpoint. If an error is provided, it's either going to be
+            // located in the query string or in a query string-like URI fragment.
+            // We try to locate the error from any of these two locations, but if we can't
+            // we just don't print anything specific.
+            let searchStr = "";
+            if (window.location.search) {
+                // window.location.searchParams isn't always defined when
+                // window.location.search is, so it's more reliable to parse the latter.
+                searchStr = window.location.search;
+            } else if (window.location.hash) {
+                // Replace the # with a ? so that URLSearchParams does the right thing and
+                // doesn't parse the first parameter incorrectly.
+                searchStr = window.location.hash.replace("#", "?");
+            }
 
-        // We might end up with no error in the URL, so we need to check if we have one
-        // to print one.
-        let errorDesc = new URLSearchParams(searchStr).get("error_description")
-        if (errorDesc) {
-            document.getElementById("errormsg").innerText = errorDesc;
-        }
-    </script>
+            // We might end up with no error in the URL, so we need to check if we have one
+            // to print one.
+            let errorDesc = new URLSearchParams(searchStr).get("error_description")
+            if (errorDesc) {
+                document.getElementById("errormsg").innerText = errorDesc;
+            }
+        </script>
 {% endif %}
 </body>
 </html>
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
new file mode 100644
index 0000000000..62a640dad2
--- /dev/null
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -0,0 +1,31 @@
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
+        <title>{{ server_name }} Login</title>
+    </head>
+    <body>
+        <div id="container">
+            <h1 id="title">{{ server_name }} Login</h1>
+            <div class="login_flow">
+                <p>Choose one of the following identity providers:</p>
+            <form>
+                <input type="hidden" name="redirectUrl" value="{{ redirect_url }}">
+                <ul class="radiobuttons">
+{% for p in providers %}
+                    <li>
+                        <input type="radio" name="idp" id="prov{{ loop.index }}" value="{{ p.idp_id }}">
+                        <label for="prov{{ loop.index }}">{{ p.idp_name }}</label>
+{% if p.idp_icon %}
+                        <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/>
+{% endif %}
+                    </li>
+{% endfor %}
+                </ul>
+                <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+            </form>
+            </div>
+        </div>
+    </body>
+</html>
diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html
new file mode 100644
index 0000000000..8c33787c54
--- /dev/null
+++ b/synapse/res/templates/sso_new_user_consent.html
@@ -0,0 +1,39 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <title>SSO redirect confirmation</title>
+    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <style type="text/css">
+      {% include "sso.css" without context %}
+
+      #consent_form {
+        margin-top: 56px;
+      }
+    </style>
+</head>
+    <body>
+        <header>
+            <h1>Your account is nearly ready</h1>
+            <p>Agree to the terms to create your account.</p>
+        </header>
+        <main>
+            <!-- {% if user_profile.avatar_url and user_profile.display_name %} -->
+            <div class="profile">
+                <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
+                <div class="profile-details">
+                    <div class="display-name">{{ user_profile.display_name }}</div>
+                    <div class="user-id">{{ user_id }}</div>
+                </div>
+            </div>
+            <!-- {% endif %} -->
+            <form method="post" action="{{my_url}}" id="consent_form">
+                <p>
+                    <input id="accepted_version" type="checkbox" name="accepted_version" value="{{ consent_version }}" required>
+                    <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank">terms and conditions</a>.</label>
+                </p>
+                <input type="submit" class="primary-button" value="Continue"/>
+            </form>
+        </main>
+    </body>
+</html>
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index 20a15e1e74..d1328a6969 100644
--- a/synapse/res/templates/sso_redirect_confirm.html
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -3,12 +3,34 @@
 <head>
     <meta charset="UTF-8">
     <title>SSO redirect confirmation</title>
+    <meta name="viewport" content="width=device-width, user-scalable=no">
+    <style type="text/css">
+      {% include "sso.css" without context %}
+    </style>
 </head>
     <body>
-        <p>The application at <span style="font-weight:bold">{{ display_url | e }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p>
-        <p>If you don't recognise this address, you should ignore this and close this tab.</p>
-        <p>
-            <a href="{{ redirect_url | e }}">I trust this address</a>
-        </p>
+        <header>
+            {% if new_user %}
+            <h1>Your account is now ready</h1>
+            <p>You've made your account on {{ server_name }}.</p>
+            {% else %}
+            <h1>Log in</h1>
+            {% endif %}
+            <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p>
+        </header>
+        <main>
+            {% if user_profile.avatar_url %}
+            <div class="profile">
+                <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" />
+                <div class="profile-details">
+                    {% if user_profile.display_name %}
+                    <div class="display-name">{{ user_profile.display_name }}</div>
+                    {% endif %}
+                    <div class="user-id">{{ user_id }}</div>
+                </div>
+            </div>
+            {% endif %}
+            <a href="{{ redirect_url }}" class="primary-button">Continue</a>
+        </main>
     </body>
-</html>
\ No newline at end of file
+</html>
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 55ddebb4fe..f5c5d164f9 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
+
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -36,10 +38,13 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
 from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.rooms import (
     DeleteRoomRestServlet,
+    ForwardExtremitiesRestServlet,
     JoinRoomAliasServlet,
     ListRoomRestServlet,
+    MakeRoomAdminRestServlet,
     RoomMembersRestServlet,
     RoomRestServlet,
+    RoomStateRestServlet,
     ShutdownRoomRestServlet,
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -50,6 +55,7 @@ from synapse.rest.admin.users import (
     PushersRestServlet,
     ResetPasswordRestServlet,
     SearchUsersRestServlet,
+    ShadowBanRestServlet,
     UserAdminServlet,
     UserMediaRestServlet,
     UserMembershipRestServlet,
@@ -208,6 +214,7 @@ def register_servlets(hs, http_server):
     """
     register_servlets_for_client_rest_resource(hs, http_server)
     ListRoomRestServlet(hs).register(http_server)
+    RoomStateRestServlet(hs).register(http_server)
     RoomRestServlet(hs).register(http_server)
     RoomMembersRestServlet(hs).register(http_server)
     DeleteRoomRestServlet(hs).register(http_server)
@@ -228,6 +235,9 @@ def register_servlets(hs, http_server):
     EventReportDetailRestServlet(hs).register(http_server)
     EventReportsRestServlet(hs).register(http_server)
     PushersRestServlet(hs).register(http_server)
+    MakeRoomAdminRestServlet(hs).register(http_server)
+    ShadowBanRestServlet(hs).register(http_server)
+    ForwardExtremitiesRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..8720b1401f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,6 +15,9 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
 
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
     assert_requester_is_admin,
     assert_user_is_admin,
 )
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
         admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, room_id: str):
+    async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
 
     PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, user_id: str):
+    async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
         "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request, server_name: str, media_id: str):
+    async def on_POST(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
         return 200, {}
 
 
+class ProtectMediaByID(RestServlet):
+    """Protect local media from being quarantined.
+    """
+
+    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        logging.info("Protecting local media by ID: %s", media_id)
+
+        # Quarantine this media id
+        await self.store.mark_local_media_as_safe(media_id)
+
+        return 200, {}
+
+
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room.
     """
 
     PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
 class PurgeMediaCacheRestServlet(RestServlet):
     PATTERNS = admin_patterns("/purge_media_cache")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.media_repository = hs.get_media_repository()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_DELETE(self, request, server_name: str, media_id: str):
+    async def on_DELETE(
+        self, request: Request, server_name: str, media_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
 
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
 
-    async def on_POST(self, request, server_name: str):
+    async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
         return 200, {"deleted_media": deleted_media, "total": total}
 
 
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
     """
     Media repo specific APIs.
     """
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
     QuarantineMediaInRoom(hs).register(http_server)
     QuarantineMediaByID(hs).register(http_server)
     QuarantineMediaByUser(hs).register(http_server)
+    ProtectMediaByID(hs).register(http_server)
     ListMediaInRoom(hs).register(http_server)
     DeleteMediaByID(hs).register(http_server)
     DeleteMediaByDateSize(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 25f89e4685..3e57e6a4d0 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
 # limitations under the License.
 import logging
 from http import HTTPStatus
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
-from synapse.api.constants import EventTypes, JoinRules
-from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -25,13 +25,18 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import (
     admin_patterns,
     assert_requester_is_admin,
     assert_user_is_admin,
 )
 from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 logger = logging.getLogger(__name__)
 
@@ -45,12 +50,14 @@ class ShutdownRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -86,13 +93,15 @@ class DeleteRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
         self.pagination_handler = hs.get_pagination_handler()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -146,12 +155,12 @@ class ListRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -236,19 +245,24 @@ class RoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room_with_stats(room_id)
         if not ret:
             raise NotFoundError("Room not found")
 
-        return 200, ret
+        members = await self.store.get_users_in_room(room_id)
+        ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+
+        return (200, ret)
 
 
 class RoomMembersRestServlet(RestServlet):
@@ -258,12 +272,14 @@ class RoomMembersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room(room_id)
@@ -276,18 +292,59 @@ class RoomMembersRestServlet(RestServlet):
         return 200, ret
 
 
+class RoomStateRestServlet(RestServlet):
+    """
+    Get full state within a room.
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        ret = await self.store.get_room(room_id)
+        if not ret:
+            raise NotFoundError("Room not found")
+
+        event_ids = await self.store.get_current_state_ids(room_id)
+        events = await self.store.get_events(event_ids.values())
+        now = self.clock.time_msec()
+        room_state = await self._event_serializer.serialize_events(
+            events.values(),
+            now,
+            # We don't bother bundling aggregations in when asked for state
+            # events, as clients won't use them.
+            bundle_aggregations=False,
+        )
+        ret = {"state": room_state}
+
+        return 200, ret
+
+
 class JoinRoomAliasServlet(RestServlet):
 
     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_member_handler = hs.get_room_member_handler()
         self.admin_handler = hs.get_admin_handler()
         self.state_handler = hs.get_state_handler()
 
-    async def on_POST(self, request, room_identifier):
+    async def on_POST(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -314,7 +371,6 @@ class JoinRoomAliasServlet(RestServlet):
             handler = self.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
             room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
-            room_id = room_id.to_string()
         else:
             raise SynapseError(
                 400, "%s was not legal room ID or room alias" % (room_identifier,)
@@ -351,3 +407,201 @@ class JoinRoomAliasServlet(RestServlet):
         )
 
         return 200, {"room_id": room_id}
+
+
+class MakeRoomAdminRestServlet(RestServlet):
+    """Allows a server admin to get power in a room if a local user has power in
+    a room. Will also invite the user if they're not in the room and it's a
+    private room. Can specify another user (rather than the admin user) to be
+    granted power, e.g.:
+
+        POST/_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
+        {
+            "user_id": "@foo:example.com"
+        }
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.room_member_handler = hs.get_room_member_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.state_handler = hs.get_state_handler()
+        self.is_mine_id = hs.is_mine_id
+
+    async def on_POST(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+        content = parse_json_object_from_request(request, allow_empty_body=True)
+
+        # Resolve to a room ID, if necessary.
+        if RoomID.is_valid(room_identifier):
+            room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+            room_id = room_id.to_string()
+        else:
+            raise SynapseError(
+                400, "%s was not legal room ID or room alias" % (room_identifier,)
+            )
+
+        # Which user to grant room admin rights to.
+        user_to_add = content.get("user_id", requester.user.to_string())
+
+        # Figure out which local users currently have power in the room, if any.
+        room_state = await self.state_handler.get_current_state(room_id)
+        if not room_state:
+            raise SynapseError(400, "Server not in room")
+
+        create_event = room_state[(EventTypes.Create, "")]
+        power_levels = room_state.get((EventTypes.PowerLevels, ""))
+
+        if power_levels is not None:
+            # We pick the local user with the highest power.
+            user_power = power_levels.content.get("users", {})
+            admin_users = [
+                user_id for user_id in user_power if self.is_mine_id(user_id)
+            ]
+            admin_users.sort(key=lambda user: user_power[user])
+
+            if not admin_users:
+                raise SynapseError(400, "No local admin user in room")
+
+            admin_user_id = None
+
+            for admin_user in reversed(admin_users):
+                if room_state.get((EventTypes.Member, admin_user)):
+                    admin_user_id = admin_user
+                    break
+
+            if not admin_user_id:
+                raise SynapseError(
+                    400, "No local admin user in room",
+                )
+
+            pl_content = power_levels.content
+        else:
+            # If there is no power level events then the creator has rights.
+            pl_content = {}
+            admin_user_id = create_event.sender
+            if not self.is_mine_id(admin_user_id):
+                raise SynapseError(
+                    400, "No local admin user in room",
+                )
+
+        # Grant the user power equal to the room admin by attempting to send an
+        # updated power level event.
+        new_pl_content = dict(pl_content)
+        new_pl_content["users"] = dict(pl_content.get("users", {}))
+        new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
+
+        fake_requester = create_requester(
+            admin_user_id, authenticated_entity=requester.authenticated_entity,
+        )
+
+        try:
+            await self.event_creation_handler.create_and_send_nonmember_event(
+                fake_requester,
+                event_dict={
+                    "content": new_pl_content,
+                    "sender": admin_user_id,
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "room_id": room_id,
+                },
+            )
+        except AuthError:
+            # The admin user we found turned out not to have enough power.
+            raise SynapseError(
+                400, "No local admin user in room with power to update power levels."
+            )
+
+        # Now we check if the user we're granting admin rights to is already in
+        # the room. If not and it's not a public room we invite them.
+        member_event = room_state.get((EventTypes.Member, user_to_add))
+        is_joined = False
+        if member_event:
+            is_joined = member_event.content["membership"] in (
+                Membership.JOIN,
+                Membership.INVITE,
+            )
+
+        if is_joined:
+            return 200, {}
+
+        join_rules = room_state.get((EventTypes.JoinRules, ""))
+        is_public = False
+        if join_rules:
+            is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
+
+        if is_public:
+            return 200, {}
+
+        await self.room_member_handler.update_membership(
+            fake_requester,
+            target=UserID.from_string(user_to_add),
+            room_id=room_id,
+            action=Membership.INVITE,
+        )
+
+        return 200, {}
+
+
+class ForwardExtremitiesRestServlet(RestServlet):
+    """Allows a server admin to get or clear forward extremities.
+
+    Clearing does not require restarting the server.
+
+        Clear forward extremities:
+        DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+        Get forward_extremities:
+        GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+    """
+
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.room_member_handler = hs.get_room_member_handler()
+        self.store = hs.get_datastore()
+
+    async def resolve_room_id(self, room_identifier: str) -> str:
+        """Resolve to a room ID, if necessary."""
+        if RoomID.is_valid(room_identifier):
+            resolved_room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+            resolved_room_id = room_id.to_string()
+        else:
+            raise SynapseError(
+                400, "%s was not legal room ID or room alias" % (room_identifier,)
+            )
+        if not resolved_room_id:
+            raise SynapseError(
+                400, "Unknown room ID or room alias %s" % room_identifier
+            )
+        return resolved_room_id
+
+    async def on_DELETE(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        room_id = await self.resolve_room_id(room_identifier)
+
+        deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+        return 200, {"deleted": deleted_count}
+
+    async def on_GET(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        room_id = await self.resolve_room_id(room_identifier)
+
+        extremities = await self.store.get_forward_extremities_for_room(room_id)
+        return 200, {"count": len(extremities), "results": extremities}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1ead..68c3c64a0d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -42,17 +42,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-_GET_PUSHERS_ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class UsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@@ -94,17 +83,32 @@ class UsersRestServletV2(RestServlet):
     The parameter `deactivated` can be used to include deactivated users.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
         user_id = parse_string(request, "user_id", default=None)
         name = parse_string(request, "name", default=None)
         guests = parse_boolean(request, "guests", default=True)
@@ -114,7 +118,7 @@ class UsersRestServletV2(RestServlet):
             start, limit, user_id, name, guests, deactivated
         )
         ret = {"users": users, "total": total}
-        if len(users) >= limit:
+        if (start + limit) < total:
             ret["next_token"] = str(start + len(users))
 
         return 200, ret
@@ -255,7 +259,7 @@ class UserRestServletV2(RestServlet):
 
                 if deactivate and not user["deactivated"]:
                     await self.deactivate_account_handler.deactivate_account(
-                        target_user.to_string(), False
+                        target_user.to_string(), False, requester, by_admin=True
                     )
                 elif not deactivate and user["deactivated"]:
                     if "password" not in body:
@@ -320,9 +324,9 @@ class UserRestServletV2(RestServlet):
                             data={},
                         )
 
-            if "avatar_url" in body and type(body["avatar_url"]) == str:
+            if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
-                    user_id, requester, body["avatar_url"], True
+                    target_user, requester, body["avatar_url"], True
                 )
 
             ret = await self.admin_handler.get_user(target_user)
@@ -420,6 +424,9 @@ class UserRegisterServlet(RestServlet):
         if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
             raise SynapseError(400, "Invalid user type")
 
+        if "mac" not in body:
+            raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
         got_mac = body["mac"]
 
         want_mac_builder = hmac.new(
@@ -494,12 +501,22 @@ class WhoisRestServlet(RestServlet):
 class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
+        self.store = hs.get_datastore()
+
+    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        if not self.is_mine(UserID.from_string(target_user_id)):
+            raise SynapseError(400, "Can only deactivate local users")
+
+        if not await self.store.get_user_by_id(target_user_id):
+            raise NotFoundError("User not found")
 
-    async def on_POST(self, request, target_user_id):
-        await assert_requester_is_admin(self.auth, request)
         body = parse_json_object_from_request(request, allow_empty_body=True)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -509,10 +526,8 @@ class DeactivateAccountRestServlet(RestServlet):
                 Codes.BAD_JSON,
             )
 
-        UserID.from_string(target_user_id)
-
         result = await self._deactivate_account_handler.deactivate_account(
-            target_user_id, erase
+            target_user_id, erase, requester, by_admin=True
         )
         if result:
             id_server_unbind_result = "success"
@@ -722,13 +737,6 @@ class UserMembershipRestServlet(RestServlet):
     async def on_GET(self, request, user_id):
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only lookup local users")
-
-        user = await self.store.get_user_by_id(user_id)
-        if user is None:
-            raise NotFoundError("Unknown user")
-
         room_ids = await self.store.get_rooms_for_user(user_id)
         ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
         return 200, ret
@@ -767,10 +775,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.store.get_pushers_by_user_id(user_id)
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
-            for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
 
         return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
 
@@ -885,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
         )
 
         return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+    """An admin API for shadow-banning a user.
+
+    A shadow-banned users receives successful responses to their client-server
+    API requests, but the events are not propagated into rooms.
+
+    Shadow-banning a user should be used as a tool of last resort and may lead
+    to confusing or broken behaviour for the client.
+
+    Example:
+
+        POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+        {}
+
+        200 OK
+        {}
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request, user_id):
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Only local users can be shadow-banned")
+
+        await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+        return 200, {}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7ae148214..0fb9419e58 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 
 import logging
-from typing import Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
     parse_json_object_from_request,
@@ -30,6 +31,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
 from synapse.types import JsonDict, UserID
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -42,7 +46,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
 
@@ -57,11 +61,14 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
 
         self.auth = hs.get_auth()
 
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
+        self._sso_handler = hs.get_sso_handler()
+
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -86,8 +93,17 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            flows.append({"type": LoginRestServlet.SSO_TYPE})
-            # While its valid for us to advertise this login type generally,
+            sso_flow = {"type": LoginRestServlet.SSO_TYPE}  # type: JsonDict
+
+            if self._msc2858_enabled:
+                sso_flow["org.matrix.msc2858.identity_providers"] = [
+                    _get_auth_flow_dict_for_idp(idp)
+                    for idp in self._sso_handler.get_identity_providers().values()
+                ]
+
+            flows.append(sso_flow)
+
+            # While it's valid for us to advertise this login type generally,
             # synapse currently only gives out these tokens as part of the
             # SSO login flow.
             # Generally we don't want to advertise login flows that clients
@@ -105,22 +121,27 @@ class LoginRestServlet(RestServlet):
         return 200, {"flows": flows}
 
     async def on_POST(self, request: SynapseRequest):
-        self._address_ratelimiter.ratelimit(request.getClientIP())
-
         login_submission = parse_json_object_from_request(request)
 
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
                 appservice = self.auth.get_appservice_by_req(request)
+
+                if appservice.is_rate_limited():
+                    self._address_ratelimiter.ratelimit(request.getClientIP())
+
                 result = await self._do_appservice_login(login_submission, appservice)
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_jwt_login(login_submission)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_token_login(login_submission)
             else:
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_other_login(login_submission)
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +180,9 @@ class LoginRestServlet(RestServlet):
         if not appservice.is_interested_in_user(qualified_user_id):
             raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
 
-        return await self._complete_login(qualified_user_id, login_submission)
+        return await self._complete_login(
+            qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+        )
 
     async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
         """Handle non-token/saml/jwt logins
@@ -194,6 +217,7 @@ class LoginRestServlet(RestServlet):
         login_submission: JsonDict,
         callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
+        ratelimit: bool = True,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -208,6 +232,7 @@ class LoginRestServlet(RestServlet):
             callback: Callback function to run after login.
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
+            ratelimit: Whether to ratelimit the login request.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -216,7 +241,8 @@ class LoginRestServlet(RestServlet):
         # Before we actually log them in we check if they've already logged in
         # too often. This happens here rather than before as we don't
         # necessarily know the user before now.
-        self._account_ratelimiter.ratelimit(user_id.lower())
+        if ratelimit:
+            self._account_ratelimiter.ratelimit(user_id.lower())
 
         if create_non_existent_users:
             canonical_uid = await self.auth_handler.check_user_exists(user_id)
@@ -298,48 +324,63 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-class BaseSSORedirectServlet(RestServlet):
-    """Common base class for /login/sso/redirect impls"""
-
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+    """Return an entry for the login flow dict
+
+    Returns an entry suitable for inclusion in "identity_providers" in the
+    response to GET /_matrix/client/r0/login
+    """
+    e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
+    if idp.idp_icon:
+        e["icon"] = idp.idp_icon
+    if idp.idp_brand:
+        e["brand"] = idp.idp_brand
+    return e
+
+
+class SsoRedirectServlet(RestServlet):
+    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+
+    def __init__(self, hs: "HomeServer"):
+        # make sure that the relevant handlers are instantiated, so that they
+        # register themselves with the main SSOHandler.
+        if hs.config.cas_enabled:
+            hs.get_cas_handler()
+        if hs.config.saml2_enabled:
+            hs.get_saml_handler()
+        if hs.config.oidc_enabled:
+            hs.get_oidc_handler()
+        self._sso_handler = hs.get_sso_handler()
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+    def register(self, http_server: HttpServer) -> None:
+        super().register(http_server)
+        if self._msc2858_enabled:
+            # expose additional endpoint for MSC2858 support
+            http_server.register_paths(
+                "GET",
+                client_patterns(
+                    "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+                    releases=(),
+                    unstable=True,
+                ),
+                self.on_GET,
+                self.__class__.__name__,
+            )
 
-    async def on_GET(self, request: SynapseRequest):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = await self.get_sso_url(request, client_redirect_url)
+    async def on_GET(
+        self, request: SynapseRequest, idp_id: Optional[str] = None
+    ) -> None:
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding=None
+        )
+        sso_url = await self._sso_handler.handle_redirect_request(
+            request, client_redirect_url, idp_id,
+        )
+        logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        """Get the URL to redirect to, to perform SSO auth
-
-        Args:
-            request: The client request to redirect.
-            client_redirect_url: the URL that we should redirect the
-                client to when everything is done
-
-        Returns:
-            URL to redirect to
-        """
-        # to be implemented by subclasses
-        raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
-    def __init__(self, hs):
-        self._cas_handler = hs.get_cas_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._cas_handler.get_redirect_url(
-            {"redirectUrl": client_redirect_url}
-        ).encode("ascii")
-
 
 class CasTicketServlet(RestServlet):
     PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -366,40 +407,8 @@ class CasTicketServlet(RestServlet):
         )
 
 
-class SAMLRedirectServlet(BaseSSORedirectServlet):
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._saml_handler = hs.get_saml_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
-    """Implementation for /login/sso/redirect for the OIDC login flow."""
-
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._oidc_handler = hs.get_oidc_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return await self._oidc_handler.handle_redirect_request(
-            request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
-        CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-    elif hs.config.saml2_enabled:
-        SAMLRedirectServlet(hs).register(http_server)
-    elif hs.config.oidc_enabled:
-        OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 8fe83f321a..89823fcc39 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 
 logger = logging.getLogger(__name__)
 
-ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class PushersRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
 
         return 200, {"pushers": filtered_pushers}
 
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 93c06afe27..f95627ee61 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.util import json_decoder
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
 if TYPE_CHECKING:
     import synapse.server
@@ -347,8 +347,6 @@ class PublicRoomListRestServlet(TransactionRestServlet):
             # provided.
             if server:
                 raise e
-            else:
-                pass
 
         limit = parse_integer(request, "limit", 0)
         since_token = parse_string(request, "since", None)
@@ -359,6 +357,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server, limit=limit, since_token=since_token
@@ -402,6 +408,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server,
@@ -963,25 +977,28 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs, http_server, is_worker=False):
     RoomStateEventRestServlet(hs).register(http_server)
-    RoomCreateRestServlet(hs).register(http_server)
     RoomMemberListRestServlet(hs).register(http_server)
     JoinedRoomMemberListRestServlet(hs).register(http_server)
     RoomMessageListRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
-    RoomForgetRestServlet(hs).register(http_server)
     RoomMembershipRestServlet(hs).register(http_server)
     RoomSendEventRestServlet(hs).register(http_server)
     PublicRoomListRestServlet(hs).register(http_server)
     RoomStateRestServlet(hs).register(http_server)
     RoomRedactEventRestServlet(hs).register(http_server)
     RoomTypingRestServlet(hs).register(http_server)
-    SearchRestServlet(hs).register(http_server)
-    JoinedRoomsRestServlet(hs).register(http_server)
-    RoomEventServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
-    RoomAliasListServlet(hs).register(http_server)
+
+    # Some servlets only get registered for the main process.
+    if not is_worker:
+        RoomCreateRestServlet(hs).register(http_server)
+        RoomForgetRestServlet(hs).register(http_server)
+        SearchRestServlet(hs).register(http_server)
+        JoinedRoomsRestServlet(hs).register(http_server)
+        RoomEventServlet(hs).register(http_server)
+        RoomAliasListServlet(hs).register(http_server)
 
 
 def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e0feebea94..b67c1702ca 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,9 +20,6 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING
 from urllib.parse import urlparse
 
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -31,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
 )
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -46,13 +44,17 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
 
 from ._base import client_patterns, interactive_auth_handler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
 class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password/email/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.datastore = hs.get_datastore()
@@ -101,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -189,11 +193,7 @@ class PasswordRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             try:
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
-                    requester,
-                    request,
-                    body,
-                    self.hs.get_ip_from_request(request),
-                    "modify your account password",
+                    requester, request, body, "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
                 # The user needs to provide more steps to complete auth, but
@@ -204,7 +204,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
             user_id = requester.user.to_string()
@@ -215,7 +217,6 @@ class PasswordRestServlet(RestServlet):
                     [[LoginType.EMAIL_IDENTITY]],
                     request,
                     body,
-                    self.hs.get_ip_from_request(request),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -227,7 +228,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
 
@@ -254,14 +257,18 @@ class PasswordRestServlet(RestServlet):
                 logger.error("Auth succeeded but no known type! %r", result.keys())
                 raise SynapseError(500, "", Codes.UNKNOWN)
 
-        # If we have a password in this request, prefer it. Otherwise, there
-        # must be a password hash from an earlier request.
+        # If we have a password in this request, prefer it. Otherwise, use the
+        # password hash from an earlier request.
         if new_password:
             password_hash = await self.auth_handler.hash(new_password)
-        else:
+        elif session_id is not None:
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
+        else:
+            # UI validation was skipped, but the request did not include a new
+            # password.
+            password_hash = None
         if not password_hash:
             raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
 
@@ -300,19 +307,18 @@ class DeactivateAccountRestServlet(RestServlet):
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase
+                requester.user.to_string(), erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "deactivate your account",
+            requester, request, body, "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(), erase, id_server=body.get("id_server")
+            requester.user.to_string(),
+            erase,
+            requester,
+            id_server=body.get("id_server"),
         )
         if result:
             id_server_unbind_result = "success"
@@ -375,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -426,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         super().__init__()
         self.store = self.hs.get_datastore()
@@ -454,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -691,11 +703,7 @@ class ThreepidAddRestServlet(RestServlet):
         assert_valid_client_secret(client_secret)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a third-party identifier to your account",
+            requester, request, body, "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..3f28c0bc3e 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_account_data_for_user(
-            user_id, account_data_type, body
-        )
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_account_data_for_user(user_id, account_data_type, body)
 
         return 200, {}
 
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
                 " Use /rooms/!roomId:server.name/read_markers",
             )
 
-        max_id = await self.store.add_account_data_to_room(
+        await self.handler.add_account_data_to_room(
             user_id, room_id, account_data_type, body
         )
 
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
         return 200, {}
 
     async def on_GET(self, request, user_id, room_id, account_data_type):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
 
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
-
-        # SSO configuration.
-        self._cas_enabled = hs.config.cas_enabled
-        if self._cas_enabled:
-            self._cas_handler = hs.get_cas_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
-        self._oidc_enabled = hs.config.oidc_enabled
-        if self._oidc_enabled:
-            self._oidc_handler = hs.get_oidc_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
         self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-            if self._cas_enabled:
-                # Generate a request to CAS that redirects back to an endpoint
-                # to verify the successful authentication.
-                sso_redirect_url = self._cas_handler.get_redirect_url(
-                    {"session": session},
-                )
-
-            elif self._saml_enabled:
-                # Some SAML identity providers (e.g. Google) require a
-                # RelayState parameter on requests. It is not necessary here, so
-                # pass in a dummy redirect URL (which will never get used).
-                client_redirect_url = b"unused"
-                sso_redirect_url = self._saml_handler.handle_redirect_request(
-                    client_redirect_url, session
-                )
-
-            elif self._oidc_enabled:
-                client_redirect_url = b""
-                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
-                    request, client_redirect_url, session
-                )
-
-            else:
-                raise SynapseError(400, "Homeserver not configured for SSO.")
-
-            html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+            html = await self.auth_handler.start_sso_ui_auth(request, session)
 
         else:
             raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"response": response, "session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+                LoginType.RECAPTCHA, authdict, request.getClientIP()
             )
 
             if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+                LoginType.TERMS, authdict, request.getClientIP()
             )
 
             if success:
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index af117cb27c..314e01dfe4 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
         assert_params_in_dict(body, ["devices"])
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove device(s) from your account",
+            requester, request, body, "remove device(s) from your account",
         )
 
         await self.device_handler.delete_devices(
@@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
                 raise
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove a device from your account",
+            requester, request, body, "remove a device from your account",
         )
 
         await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 75215a3779..28b55f27ad 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from functools import wraps
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -25,6 +26,22 @@ from ._base import client_patterns
 logger = logging.getLogger(__name__)
 
 
+def _validate_group_id(f):
+    """Wrapper to validate the form of the group ID.
+
+    Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
+    """
+
+    @wraps(f)
+    def wrapper(self, request, group_id, *args, **kwargs):
+        if not GroupID.is_valid(group_id):
+            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
+        return f(self, request, group_id, *args, **kwargs)
+
+    return wrapper
+
+
 class GroupServlet(RestServlet):
     """Get the group profile
     """
@@ -37,6 +54,7 @@ class GroupServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -47,6 +65,7 @@ class GroupServlet(RestServlet):
 
         return 200, group_description
 
+    @_validate_group_id
     async def on_POST(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, category_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, category_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet):
 
         return 200, category
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet):
 
         return 200, category
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, role_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, role_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
-        if not GroupID.is_valid(group_id):
-            raise SynapseError(400, "%s was not legal group ID" % (group_id,))
-
         result = await self.groups_handler.get_rooms_in_group(
             group_id, requester_user_id
         )
@@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
         self.auth = hs.get_auth()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
 
         return 200, result
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, room_id, config_key):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -590,6 +628,7 @@ class GroupSelfLeaveServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -614,6 +653,7 @@ class GroupSelfJoinServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -638,6 +678,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -662,6 +703,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index b91996c738..a6134ead8a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
         body = parse_json_object_from_request(request)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a device signing key to your account",
+            requester, request, body, "add a device signing key to your account",
         )
 
         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5374d2c1b6..f0675abd32 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.config.registration import RegistrationConfig
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -125,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
         )
@@ -204,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "msisdn", msisdn
         )
@@ -353,7 +360,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
                 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
             )
 
-        ip = self.hs.get_ip_from_request(request)
+        ip = request.getClientIP()
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             await wait_deferred
 
@@ -451,7 +458,7 @@ class RegisterRestServlet(RestServlet):
 
         # == Normal User Registration == (everyone else)
         if not self._registration_enabled:
-            raise SynapseError(403, "Registration has been disabled")
+            raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
 
         # For regular registration, convert the provided username to lowercase
         # before attempting to register it. This should mean that people who try
@@ -494,11 +501,11 @@ class RegisterRestServlet(RestServlet):
             # user here. We carry on and go through the auth checks though,
             # for paranoia.
             registered_user_id = await self.auth_handler.get_session_data(
-                session_id, "registered_user_id", None
+                session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
             )
             # Extract the previously-hashed password from the session.
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
 
         # Ensure that the username is valid.
@@ -513,11 +520,7 @@ class RegisterRestServlet(RestServlet):
         # not this will raise a user-interactive auth error.
         try:
             auth_result, params, session_id = await self.auth_handler.check_ui_auth(
-                self._registration_flows,
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "register a new account",
+                self._registration_flows, request, body, "register a new account",
             )
         except InteractiveAuthIncompleteError as e:
             # The user needs to provide more steps to complete auth.
@@ -532,7 +535,9 @@ class RegisterRestServlet(RestServlet):
             if not password_hash and password:
                 password_hash = await self.auth_handler.hash(password)
                 await self.auth_handler.set_session_data(
-                    e.session_id, "password_hash", password_hash
+                    e.session_id,
+                    UIAuthSessionDataConstants.PASSWORD_HASH,
+                    password_hash,
                 )
             raise
 
@@ -635,7 +640,9 @@ class RegisterRestServlet(RestServlet):
             # Remember that the user account has been registered (and the user
             # ID it was registered with, since it might not have been specified).
             await self.auth_handler.set_session_data(
-                session_id, "registered_user_id", registered_user_id
+                session_id,
+                UIAuthSessionDataConstants.REGISTERED_USER_ID,
+                registered_user_id,
             )
 
             registered = True
@@ -657,9 +664,13 @@ class RegisterRestServlet(RestServlet):
         user_id = await self.registration_handler.appservice_register(
             username, as_token
         )
-        return await self._create_registration_details(user_id, body)
+        return await self._create_registration_details(
+            user_id, body, is_appservice_ghost=True,
+        )
 
-    async def _create_registration_details(self, user_id, params):
+    async def _create_registration_details(
+        self, user_id, params, is_appservice_ghost=False
+    ):
         """Complete registration of newly-registered user
 
         Allocates device_id if one was not given; also creates access_token.
@@ -676,7 +687,11 @@ class RegisterRestServlet(RestServlet):
             device_id = params.get("device_id")
             initial_display_name = params.get("initial_device_display_name")
             device_id, access_token = await self.registration_handler.register_device(
-                user_id, device_id, initial_display_name, is_guest=False
+                user_id,
+                device_id,
+                initial_display_name,
+                is_guest=False,
+                is_appservice_ghost=is_appservice_ghost,
             )
 
             result.update({"access_token": access_token, "device_id": device_id})
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index bc4f43639a..a3dee14ed4 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -17,7 +17,7 @@ import logging
 from typing import Tuple
 
 from synapse.http import servlet
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
 from synapse.logging.opentracing import set_tag, trace
 from synapse.rest.client.transactions import HttpTransactionCache
 
@@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
+        assert_params_in_dict(content, ("messages",))
 
         sender_user_id = requester.user.to_string()
 
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
     def __init__(self, hs):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, tag):
         requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_tag_to_room(user_id, room_id, tag, body)
 
         return 200, {}
 
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
 
-        max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.remove_tag_from_room(user_id, room_id, tag)
 
         return 200, {}
 
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index b3e4d5612e..8b9ef26cf2 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource):
 
         consent_template_directory = hs.config.user_consent_template_dir
 
+        # TODO: switch to synapse.util.templates.build_jinja_env
         loader = jinja2.FileSystemLoader(consent_template_directory)
         self._jinja_env = jinja2.Environment(
             loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f843f02454..c57ac22e58 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Set
+from typing import Dict
 
 from signedjson.sign import sign_json
 
@@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):
 
         time_now_ms = self.clock.time_msec()
 
-        cache_misses = {}  # type: Dict[str, Set[str]]
+        # Note that the value is unused.
+        cache_misses = {}  # type: Dict[str, Dict[str, int]]
         for (server_name, key_id, from_server), results in cached.items():
             results = [(result["ts_added_ms"], result) for result in results]
 
             if not results and key_id is not None:
-                cache_misses.setdefault(server_name, set()).add(key_id)
+                cache_misses.setdefault(server_name, {})[key_id] = 0
                 continue
 
             if key_id is not None:
@@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
                     )
 
                 if miss:
-                    cache_misses.setdefault(server_name, set()).add(key_id)
+                    cache_misses.setdefault(server_name, {})[key_id] = 0
                 # Cast to bytes since postgresql returns a memoryview.
                 json_results.add(bytes(most_recent_result["key_json"]))
             else:
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 67aa993f19..f71a03a12d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
 import logging
 import os
 import urllib
-from typing import Awaitable
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple
 
 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
@@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
 ]
 
 
-def parse_media_id(request):
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
     try:
         # This allows users to append e.g. /test.png to the URL. Useful for
         # clients that parse the URL to see content type.
@@ -69,7 +70,7 @@ def parse_media_id(request):
         )
 
 
-def respond_404(request):
+def respond_404(request: Request) -> None:
     respond_with_json(
         request,
         404,
@@ -79,8 +80,12 @@ def respond_404(request):
 
 
 async def respond_with_file(
-    request, media_type, file_path, file_size=None, upload_name=None
-):
+    request: Request,
+    media_type: str,
+    file_path: str,
+    file_size: Optional[int] = None,
+    upload_name: Optional[str] = None,
+) -> None:
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
@@ -98,15 +103,20 @@ async def respond_with_file(
         respond_404(request)
 
 
-def add_file_headers(request, media_type, file_size, upload_name):
+def add_file_headers(
+    request: Request,
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str],
+) -> None:
     """Adds the correct response headers in preparation for responding with the
     media.
 
     Args:
-        request (twisted.web.http.Request)
-        media_type (str): The media/content type.
-        file_size (int): Size in bytes of the media, if known.
-        upload_name (str): The name of the requested file, if any.
+        request
+        media_type: The media/content type.
+        file_size: Size in bytes of the media, if known.
+        upload_name: The name of the requested file, if any.
     """
 
     def _quote(x):
@@ -153,7 +163,13 @@ def add_file_headers(request, media_type, file_size, upload_name):
     # select private. don't bother setting Expires as all our
     # clients are smart enough to be happy with Cache-Control
     request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
-    request.setHeader(b"Content-Length", b"%d" % (file_size,))
+    if file_size is not None:
+        request.setHeader(b"Content-Length", b"%d" % (file_size,))
+
+    # Tell web crawlers to not index, archive, or follow links in media. This
+    # should help to prevent things in the media repo from showing up in web
+    # search results.
+    request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
 
 
 # separators as defined in RFC2616. SP and HT are handled separately.
@@ -179,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
 }
 
 
-def _can_encode_filename_as_token(x):
+def _can_encode_filename_as_token(x: str) -> bool:
     for c in x:
         # from RFC2616:
         #
@@ -201,17 +217,21 @@ def _can_encode_filename_as_token(x):
 
 
 async def respond_with_responder(
-    request, responder, media_type, file_size, upload_name=None
-):
+    request: Request,
+    responder: "Optional[Responder]",
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str] = None,
+) -> None:
     """Responds to the request with given responder. If responder is None then
     returns 404.
 
     Args:
-        request (twisted.web.http.Request)
-        responder (Responder|None)
-        media_type (str): The media/content type.
-        file_size (int|None): Size in bytes of the media. If not known it should be None
-        upload_name (str|None): The name of the requested file, if any.
+        request
+        responder
+        media_type: The media/content type.
+        file_size: Size in bytes of the media. If not known it should be None
+        upload_name: The name of the requested file, if any.
     """
     if request._disconnected:
         logger.warning(
@@ -280,6 +300,7 @@ class FileInfo:
         thumbnail_height (int)
         thumbnail_method (str)
         thumbnail_type (str): Content type of thumbnail, e.g. image/png
+        thumbnail_length (int): The size of the media file, in bytes.
     """
 
     def __init__(
@@ -292,6 +313,7 @@ class FileInfo:
         thumbnail_height=None,
         thumbnail_method=None,
         thumbnail_type=None,
+        thumbnail_length=None,
     ):
         self.server_name = server_name
         self.file_id = file_id
@@ -301,24 +323,25 @@ class FileInfo:
         self.thumbnail_height = thumbnail_height
         self.thumbnail_method = thumbnail_method
         self.thumbnail_type = thumbnail_type
+        self.thumbnail_length = thumbnail_length
 
 
-def get_filename_from_headers(headers):
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
     """
     Get the filename of the downloaded file by inspecting the
     Content-Disposition HTTP header.
 
     Args:
-        headers (dict[bytes, list[bytes]]): The HTTP request headers.
+        headers: The HTTP request headers.
 
     Returns:
-        A Unicode string of the filename, or None.
+        The filename, or None.
     """
     content_disposition = headers.get(b"Content-Disposition", [b""])
 
     # No header, bail out.
     if not content_disposition[0]:
-        return
+        return None
 
     _, params = _parse_header(content_disposition[0])
 
@@ -351,17 +374,16 @@ def get_filename_from_headers(headers):
     return upload_name
 
 
-def _parse_header(line):
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
     """Parse a Content-type like header.
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        line (bytes): header to be parsed
+        line: header to be parsed
 
     Returns:
-        Tuple[bytes, dict[bytes, bytes]]:
-            the main content-type, followed by the parameter dictionary
+        The main content-type, followed by the parameter dictionary
     """
     parts = _parseparam(b";" + line)
     key = next(parts)
@@ -381,16 +403,16 @@ def _parse_header(line):
     return key, pdict
 
 
-def _parseparam(s):
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
     """Generator which splits the input on ;, respecting double-quoted sequences
 
     Cargo-culted from `cgi`, but works on bytes rather than strings.
 
     Args:
-        s (bytes): header to be parsed
+        s: header to be parsed
 
     Returns:
-        Iterable[bytes]: the split input
+        The split input
     """
     while s[:1] == b";":
         s = s[1:]
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 68dd2a1c8a..4e4c6971f7 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 Will Hunt <will@half-shot.uk>
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,22 +15,29 @@
 # limitations under the License.
 #
 
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class MediaConfigResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         config = hs.get_config()
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index d3d8457303..3ed219ae43 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
-import synapse.http.servlet
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.servlet import parse_boolean
 
 from ._base import parse_media_id, respond_404
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class DownloadResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
         self.media_repo = media_repo
         self.server_name = hs.hostname
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         request.setHeader(
             b"Content-Security-Policy",
@@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
         if server_name == self.server_name:
             await self.media_repo.get_local_media(request, media_id, name)
         else:
-            allow_remote = synapse.http.servlet.parse_boolean(
-                request, "allow_remote", default=True
-            )
+            allow_remote = parse_boolean(request, "allow_remote", default=True)
             if not allow_remote:
                 logger.info(
                     "Rejecting request for remote media %s/%s due to allow_remote",
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 9e079f672f..7792f26e78 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
 import functools
 import os
 import re
+from typing import Callable, List
 
 NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
 
 
-def _wrap_in_base_path(func):
+def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
     """Takes a function that returns a relative path and turns it into an
     absolute path based on the location of the primary media store
     """
@@ -41,12 +43,18 @@ class MediaFilePaths:
     to write to the backup media store (when one is configured)
     """
 
-    def __init__(self, primary_base_path):
+    def __init__(self, primary_base_path: str):
         self.base_path = primary_base_path
 
     def default_thumbnail_rel(
-        self, default_top_level, default_sub_type, width, height, content_type, method
-    ):
+        self,
+        default_top_level: str,
+        default_sub_type: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:
 
     default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
 
-    def local_media_filepath_rel(self, media_id):
+    def local_media_filepath_rel(self, media_id: str) -> str:
         return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
 
     local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
 
-    def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def local_media_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
             media_id[4:],
         )
 
-    def remote_media_filepath_rel(self, server_name, file_id):
+    def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
         )
@@ -94,8 +104,14 @@ class MediaFilePaths:
     remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
 
     def remote_media_thumbnail_rel(
-        self, server_name, file_id, width, height, content_type, method
-    ):
+        self,
+        server_name: str,
+        file_id: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
         return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
     # Should be removed after some time, when most of the thumbnails are stored
     # using the new path.
     def remote_media_thumbnail_rel_legacy(
-        self, server_name, file_id, width, height, content_type
+        self, server_name: str, file_id: str, width: int, height: int, content_type: str
     ):
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@@ -126,7 +142,7 @@ class MediaFilePaths:
             file_name,
         )
 
-    def remote_media_thumbnail_dir(self, server_name, file_id):
+    def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
         return os.path.join(
             self.base_path,
             "remote_thumbnail",
@@ -136,7 +152,7 @@ class MediaFilePaths:
             file_id[4:],
         )
 
-    def url_cache_filepath_rel(self, media_id):
+    def url_cache_filepath_rel(self, media_id: str) -> str:
         if NEW_FORMAT_ID_RE.match(media_id):
             # Media id is of the form <DATE><RANDOM_STRING>
             # E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -146,7 +162,7 @@ class MediaFilePaths:
 
     url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
 
-    def url_cache_filepath_dirs_to_delete(self, media_id):
+    def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id file"
         if NEW_FORMAT_ID_RE.match(media_id):
             return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@@ -156,7 +172,9 @@ class MediaFilePaths:
                 os.path.join(self.base_path, "url_cache", media_id[0:2]),
             ]
 
-    def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
+    def url_cache_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
 
@@ -178,7 +196,7 @@ class MediaFilePaths:
 
     url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
 
-    def url_cache_thumbnail_directory(self, media_id):
+    def url_cache_thumbnail_directory(self, media_id: str) -> str:
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
 
@@ -195,7 +213,7 @@ class MediaFilePaths:
                 media_id[4:],
             )
 
-    def url_cache_thumbnail_dirs_to_delete(self, media_id):
+    def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
         "The dirs to try and remove if we delete the media_id thumbnails"
         # Media id is of the form <DATE><RANDOM_STRING>
         # E.g.: 2017-09-28-fsdRDt24DS234dsf
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9cac74ebd8..4c9946a616 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import errno
 import logging
 import os
 import shutil
-from typing import IO, Dict, List, Optional, Tuple
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 import twisted.internet.error
 import twisted.web.http
@@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
 from .thumbnailer import Thumbnailer, ThumbnailError
 from .upload_resource import UploadResource
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -63,26 +66,26 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 
 
 class MediaRepository:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
         self.max_upload_size = hs.config.max_upload_size
         self.max_image_pixels = hs.config.max_image_pixels
 
-        self.primary_base_path = hs.config.media_store_path
-        self.filepaths = MediaFilePaths(self.primary_base_path)
+        self.primary_base_path = hs.config.media_store_path  # type: str
+        self.filepaths = MediaFilePaths(self.primary_base_path)  # type: MediaFilePaths
 
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements
 
         self.remote_media_linearizer = Linearizer(name="media_remote")
 
-        self.recently_accessed_remotes = set()
-        self.recently_accessed_locals = set()
+        self.recently_accessed_remotes = set()  # type: Set[Tuple[str, str]]
+        self.recently_accessed_locals = set()  # type: Set[str]
 
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
@@ -113,7 +116,7 @@ class MediaRepository:
             "update_recently_accessed_media", self._update_recently_accessed
         )
 
-    async def _update_recently_accessed(self):
+    async def _update_recently_accessed(self) -> None:
         remote_media = self.recently_accessed_remotes
         self.recently_accessed_remotes = set()
 
@@ -124,12 +127,12 @@ class MediaRepository:
             local_media, remote_media, self.clock.time_msec()
         )
 
-    def mark_recently_accessed(self, server_name, media_id):
+    def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
         """Mark the given media as recently accessed.
 
         Args:
-            server_name (str|None): Origin server of media, or None if local
-            media_id (str): The media ID of the content
+            server_name: Origin server of media, or None if local
+            media_id: The media ID of the content
         """
         if server_name:
             self.recently_accessed_remotes.add((server_name, media_id))
@@ -459,7 +462,14 @@ class MediaRepository:
     def _get_thumbnail_requirements(self, media_type):
         return self.thumbnail_requirements.get(media_type, ())
 
-    def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
+    def _generate_thumbnail(
+        self,
+        thumbnailer: Thumbnailer,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[BytesIO]:
         m_width = thumbnailer.width
         m_height = thumbnailer.height
 
@@ -470,22 +480,20 @@ class MediaRepository:
                 m_height,
                 self.max_image_pixels,
             )
-            return
+            return None
 
         if thumbnailer.transpose_method is not None:
             m_width, m_height = thumbnailer.transpose()
 
         if t_method == "crop":
-            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
+            return thumbnailer.crop(t_width, t_height, t_type)
         elif t_method == "scale":
             t_width, t_height = thumbnailer.aspect(t_width, t_height)
             t_width = min(m_width, t_width)
             t_height = min(m_height, t_height)
-            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
-        else:
-            t_byte_source = None
+            return thumbnailer.scale(t_width, t_height, t_type)
 
-        return t_byte_source
+        return None
 
     async def generate_local_exact_thumbnail(
         self,
@@ -776,7 +784,7 @@ class MediaRepository:
 
         return {"width": m_width, "height": m_height}
 
-    async def delete_old_remote_media(self, before_ts):
+    async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
         old_media = await self.store.get_remote_media_before(before_ts)
 
         deleted = 0
@@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
     within a given rectangle.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         # If we're not configured to use it, raise if we somehow got here.
         if not hs.config.can_load_media_repo:
             raise ConfigError("Synapse is not configured to use a media repo.")
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 268e0c8f50..89cdd605aa 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import os
 import shutil
 from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
 
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
 
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@@ -270,7 +272,7 @@ class MediaStorage:
         return self.filepaths.local_media_filepath_rel(file_info.file_id)
 
 
-def _write_file_synchronously(source, dest):
+def _write_file_synchronously(source: IO, dest: IO) -> None:
     """Write `source` to the file like `dest` synchronously. Should be called
     from a thread.
 
@@ -286,14 +288,14 @@ class FileResponder(Responder):
     """Wraps an open file that can be sent to a request.
 
     Args:
-        open_file (file): A file like object to be streamed ot the client,
+        open_file: A file like object to be streamed ot the client,
             is closed when finished streaming.
     """
 
-    def __init__(self, open_file):
+    def __init__(self, open_file: IO):
         self.open_file = open_file
 
-    def write_to_consumer(self, consumer):
+    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
         return make_deferred_yieldable(
             FileSender().beginFileTransfer(self.open_file, consumer)
         )
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index dce6c4d168..bf3be653aa 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import datetime
 import errno
 import fnmatch
@@ -23,12 +23,13 @@ import re
 import shutil
 import sys
 import traceback
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
 from urllib import parse as urlparse
 
 import attr
 
 from twisted.internet.error import DNSLookupError
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
@@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
 
 from ._base import FileInfo
 
+if TYPE_CHECKING:
+    from lxml import etree
+
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@@ -119,7 +127,12 @@ class OEmbedError(Exception):
 class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()
 
         self.auth = hs.get_auth()
@@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
                 self._start_expire_url_cache_data, 10 * 1000
             )
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         request.setHeader(b"Allow", b"OPTIONS, GET")
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
 
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
@@ -373,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         """
         Check whether the URL should be downloaded as oEmbed content instead.
 
-        Params:
+        Args:
             url: The URL to check.
 
         Returns:
@@ -390,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         """
         Request content from an oEmbed endpoint.
 
-        Params:
+        Args:
             endpoint: The oEmbed API endpoint.
             url: The URL to pass to the API.
 
@@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
             raise OEmbedError() from e
 
-    async def _download_url(self, url: str, user):
+    async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             "expire_url_cache_data", self._expire_url_cache_data
         )
 
-    async def _expire_url_cache_data(self):
+    async def _expire_url_cache_data(self) -> None:
         """Clean up expired url cache content, media and thumbnails.
         """
         # TODO: Delete from backup media store
@@ -676,24 +689,54 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
-def decode_and_calc_og(body, media_uri, request_encoding=None):
+def decode_and_calc_og(
+    body: bytes, media_uri: str, request_encoding: Optional[str] = None
+) -> Dict[str, Optional[str]]:
+    """
+    Calculate metadata for an HTML document.
+
+    This uses lxml to parse the HTML document into the OG response. If errors
+    occur during processing of the document, an empty response is returned.
+
+    Args:
+        body: The HTML document, as bytes.
+        media_url: The URI used to download the body.
+        request_encoding: The character encoding of the body, as a string.
+
+    Returns:
+        The OG response as a dictionary.
+    """
+    # If there's no body, nothing useful is going to be found.
+    if not body:
+        return {}
+
     from lxml import etree
 
+    # Create an HTML parser. If this fails, log and return no metadata.
     try:
         parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body, parser)
-        og = _calc_og(tree, media_uri)
+    except LookupError:
+        # blindly consider the encoding as utf-8.
+        parser = etree.HTMLParser(recover=True, encoding="utf-8")
+    except Exception as e:
+        logger.warning("Unable to create HTML parser: %s" % (e,))
+        return {}
+
+    def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+        # Attempt to parse the body. If this fails, log and return no metadata.
+        tree = etree.fromstring(body_attempt, parser)
+        return _calc_og(tree, media_uri)
+
+    # Attempt to parse the body. If this fails, log and return no metadata.
+    try:
+        return _attempt_calc_og(body)
     except UnicodeDecodeError:
         # blindly try decoding the body as utf-8, which seems to fix
         # the charset mismatches on https://google.com
-        parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
-        og = _calc_og(tree, media_uri)
-
-    return og
+        return _attempt_calc_og(body.decode("utf-8", "ignore"))
 
 
-def _calc_og(tree, media_uri):
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
     # suck our tree into lxml and define our OG response.
 
     # if we see any image URLs in the OG response, then spider them
@@ -797,7 +840,9 @@ def _calc_og(tree, media_uri):
                 for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
             )
             og["og:description"] = summarize_paragraphs(text_nodes)
-    else:
+    elif og["og:description"]:
+        # This must be a non-empty string at this point.
+        assert isinstance(og["og:description"], str)
         og["og:description"] = summarize_paragraphs([og["og:description"]])
 
     # TODO: delete the url downloads to stop diskfilling,
@@ -805,7 +850,9 @@ def _calc_og(tree, media_uri):
     return og
 
 
-def _iterate_over_text(tree, *tags_to_ignore):
+def _iterate_over_text(
+    tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
     """Iterate over the tree returning text nodes in a depth first fashion,
     skipping text nodes inside certain tags.
     """
@@ -839,32 +886,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
             )
 
 
-def _rebase_url(url, base):
-    base = list(urlparse.urlparse(base))
-    url = list(urlparse.urlparse(url))
-    if not url[0]:  # fix up schema
-        url[0] = base[0] or "http"
-    if not url[1]:  # fix up hostname
-        url[1] = base[1]
-        if not url[2].startswith("/"):
-            url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
-    return urlparse.urlunparse(url)
+def _rebase_url(url: str, base: str) -> str:
+    base_parts = list(urlparse.urlparse(base))
+    url_parts = list(urlparse.urlparse(url))
+    if not url_parts[0]:  # fix up schema
+        url_parts[0] = base_parts[0] or "http"
+    if not url_parts[1]:  # fix up hostname
+        url_parts[1] = base_parts[1]
+        if not url_parts[2].startswith("/"):
+            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+    return urlparse.urlunparse(url_parts)
 
 
-def _is_media(content_type):
-    if content_type.lower().startswith("image/"):
-        return True
+def _is_media(content_type: str) -> bool:
+    return content_type.lower().startswith("image/")
 
 
-def _is_html(content_type):
+def _is_html(content_type: str) -> bool:
     content_type = content_type.lower()
-    if content_type.startswith("text/html") or content_type.startswith(
+    return content_type.startswith("text/html") or content_type.startswith(
         "application/xhtml"
-    ):
-        return True
+    )
 
 
-def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+def summarize_paragraphs(
+    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
     # Try to get a summary of between 200 and 500 words, respecting
     # first paragraph and then word boundaries.
     # TODO: Respect sentences?
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48d6..e92006faa9 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,27 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
+import abc
 import logging
 import os
 import shutil
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
 
 from ._base import FileInfo, Responder
 from .media_storage import FileResponder
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
-class StorageProvider:
+
+class StorageProvider(metaclass=abc.ABCMeta):
     """A storage provider is a service that can store uploaded media and
     retrieve them.
     """
 
-    async def store_file(self, path: str, file_info: FileInfo):
+    @abc.abstractmethod
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """Store the file described by file_info. The actual contents can be
         retrieved by reading the file in file_info.upload_path.
 
@@ -42,6 +47,7 @@ class StorageProvider:
             file_info: The metadata of the file.
         """
 
+    @abc.abstractmethod
     async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """Attempt to fetch the file described by file_info and stream it
         into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
         self.store_synchronous = store_synchronous
         self.store_remote = store_remote
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "StorageProviderWrapper[%s]" % (self.backend,)
 
-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         if not file_info.server_name and not self.store_local:
             return None
 
@@ -91,39 +97,34 @@ class StorageProviderWrapper(StorageProvider):
         if self.store_synchronous:
             # store_file is supposed to return an Awaitable, but guard
             # against improper implementations.
-            result = self.backend.store_file(path, file_info)
-            if inspect.isawaitable(result):
-                return await result
+            await maybe_awaitable(self.backend.store_file(path, file_info))  # type: ignore
         else:
             # TODO: Handle errors.
             async def store():
                 try:
-                    result = self.backend.store_file(path, file_info)
-                    if inspect.isawaitable(result):
-                        return await result
+                    return await maybe_awaitable(
+                        self.backend.store_file(path, file_info)
+                    )
                 except Exception:
                     logger.exception("Error storing file")
 
             run_in_background(store)
-            return None
 
-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         # store_file is supposed to return an Awaitable, but guard
         # against improper implementations.
-        result = self.backend.fetch(path, file_info)
-        if inspect.isawaitable(result):
-            return await result
+        return await maybe_awaitable(self.backend.fetch(path, file_info))
 
 
 class FileStorageProviderBackend(StorageProvider):
     """A storage provider that stores files in a directory on a filesystem.
 
     Args:
-        hs (HomeServer)
+        hs
         config: The config returned by `parse_config`.
     """
 
-    def __init__(self, hs, config):
+    def __init__(self, hs: "HomeServer", config: str):
         self.hs = hs
         self.cache_directory = hs.config.media_store_path
         self.base_directory = config
@@ -131,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
     def __str__(self):
         return "FileStorageProviderBackend[%s]" % (self.base_directory,)
 
-    async def store_file(self, path, file_info):
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
         """See StorageProvider.store_file"""
 
         primary_fname = os.path.join(self.cache_directory, path)
@@ -141,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
         if not os.path.exists(dirname):
             os.makedirs(dirname)
 
-        return await defer_to_thread(
+        await defer_to_thread(
             self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
         )
 
-    async def fetch(self, path, file_info):
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
         """See StorageProvider.fetch"""
 
         backup_fname = os.path.join(self.base_directory, path)
         if os.path.isfile(backup_fname):
             return FileResponder(open(backup_fname, "rb"))
 
+        return None
+
     @staticmethod
-    def parse_config(config):
+    def parse_config(config: dict) -> str:
         """Called on startup to parse config supplied. This should parse
         the config and raise if there is a problem.
 
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 30421b663a..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@
 
 
 import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from twisted.web.http import Request
 
 from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.rest.media.v1.media_storage import MediaStorage
 
 from ._base import (
     FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
     respond_with_responder,
 )
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class ThumbnailResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo, media_storage):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
         super().__init__()
 
         self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
 
-    async def _async_render_GET(self, request):
+    async def _async_render_GET(self, request: Request) -> None:
         set_cors_headers(request)
         server_name, media_id, _ = parse_media_id(request)
         width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
     async def _respond_local_thumbnail(
-        self, request, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
@@ -86,41 +106,27 @@ class ThumbnailResource(DirectServeJsonResource):
             return
 
         thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
-        if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
-            )
-
-            file_info = FileInfo(
-                server_name=None,
-                file_id=media_id,
-                url_cache=media_info["url_cache"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
-
-            responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
-        else:
-            logger.info("Couldn't find any generated thumbnails")
-            respond_404(request)
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_id,
+            url_cache=media_info["url_cache"],
+            server_name=None,
+        )
 
     async def _select_or_generate_local_thumbnail(
         self,
-        request,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
@@ -178,14 +184,14 @@ class ThumbnailResource(DirectServeJsonResource):
 
     async def _select_or_generate_remote_thumbnail(
         self,
-        request,
-        server_name,
-        media_id,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-    ):
+        request: Request,
+        server_name: str,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+    ) -> None:
         media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +245,15 @@ class ThumbnailResource(DirectServeJsonResource):
             raise SynapseError(400, "Failed to generate thumbnail.")
 
     async def _respond_remote_thumbnail(
-        self, request, server_name, media_id, width, height, method, m_type
-    ):
+        self,
+        request: Request,
+        server_name: str,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+    ) -> None:
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
@@ -249,97 +262,185 @@ class ThumbnailResource(DirectServeJsonResource):
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
         )
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_info["filesystem_id"],
+            url_cache=None,
+            server_name=server_name,
+        )
 
+    async def _select_and_respond_with_thumbnail(
+        self,
+        request: Request,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str] = None,
+        server_name: Optional[str] = None,
+    ) -> None:
+        """
+        Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            request: The incoming request.
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+        """
         if thumbnail_infos:
-            thumbnail_info = self._select_thumbnail(
-                width, height, method, m_type, thumbnail_infos
+            file_info = self._select_thumbnail(
+                desired_width,
+                desired_height,
+                desired_method,
+                desired_type,
+                thumbnail_infos,
+                file_id,
+                url_cache,
+                server_name,
             )
-            file_info = FileInfo(
-                server_name=server_name,
-                file_id=media_info["filesystem_id"],
-                thumbnail=True,
-                thumbnail_width=thumbnail_info["thumbnail_width"],
-                thumbnail_height=thumbnail_info["thumbnail_height"],
-                thumbnail_type=thumbnail_info["thumbnail_type"],
-                thumbnail_method=thumbnail_info["thumbnail_method"],
-            )
-
-            t_type = file_info.thumbnail_type
-            t_length = thumbnail_info["thumbnail_length"]
+            if not file_info:
+                logger.info("Couldn't find a thumbnail matching the desired inputs")
+                respond_404(request)
+                return
 
             responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(request, responder, t_type, t_length)
+            await respond_with_responder(
+                request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+            )
         else:
             logger.info("Failed to find any generated thumbnails")
             respond_404(request)
 
     def _select_thumbnail(
         self,
-        desired_width,
-        desired_height,
-        desired_method,
-        desired_type,
-        thumbnail_infos,
-    ):
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[Dict[str, Any]],
+        file_id: str,
+        url_cache: Optional[str],
+        server_name: Optional[str],
+    ) -> Optional[FileInfo]:
+        """
+        Choose an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of dictionaries of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: The URL cache value.
+            server_name: The server name, if this is a remote thumbnail.
+
+        Returns:
+             The thumbnail which best matches the desired parameters.
+        """
+        desired_method = desired_method.lower()
+
+        # The chosen thumbnail.
+        thumbnail_info = None
+
         d_w = desired_width
         d_h = desired_height
 
-        if desired_method.lower() == "crop":
+        if desired_method == "crop":
+            # Thumbnails that match equal or larger sizes of desired width/height.
             crop_info_list = []
+            # Other thumbnails.
             crop_info_list2 = []
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "crop":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
-                if t_method == "crop":
-                    aspect_quality = abs(d_w * t_h - d_h * t_w)
-                    min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
-                    size_quality = abs((d_w - t_w) * (d_h - t_h))
-                    type_quality = desired_type != info["thumbnail_type"]
-                    length_quality = info["thumbnail_length"]
-                    if t_w >= d_w or t_h >= d_h:
-                        crop_info_list.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                aspect_quality = abs(d_w * t_h - d_h * t_w)
+                min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+                size_quality = abs((d_w - t_w) * (d_h - t_h))
+                type_quality = desired_type != info["thumbnail_type"]
+                length_quality = info["thumbnail_length"]
+                if t_w >= d_w or t_h >= d_h:
+                    crop_info_list.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
-                    else:
-                        crop_info_list2.append(
-                            (
-                                aspect_quality,
-                                min_quality,
-                                size_quality,
-                                type_quality,
-                                length_quality,
-                                info,
-                            )
+                    )
+                else:
+                    crop_info_list2.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
                         )
+                    )
             if crop_info_list:
-                return min(crop_info_list)[-1]
-            else:
-                return min(crop_info_list2)[-1]
-        else:
+                thumbnail_info = min(crop_info_list)[-1]
+            elif crop_info_list2:
+                thumbnail_info = min(crop_info_list2)[-1]
+        elif desired_method == "scale":
+            # Thumbnails that match equal or larger sizes of desired width/height.
             info_list = []
+            # Other thumbnails.
             info_list2 = []
+
             for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info["thumbnail_method"] != "scale":
+                    continue
+
                 t_w = info["thumbnail_width"]
                 t_h = info["thumbnail_height"]
-                t_method = info["thumbnail_method"]
                 size_quality = abs((d_w - t_w) * (d_h - t_h))
                 type_quality = desired_type != info["thumbnail_type"]
                 length_quality = info["thumbnail_length"]
-                if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+                if t_w >= d_w or t_h >= d_h:
                     info_list.append((size_quality, type_quality, length_quality, info))
-                elif t_method == "scale":
+                else:
                     info_list2.append(
                         (size_quality, type_quality, length_quality, info)
                     )
             if info_list:
-                return min(info_list)[-1]
-            else:
-                return min(info_list2)[-1]
+                thumbnail_info = min(info_list)[-1]
+            elif info_list2:
+                thumbnail_info = min(info_list2)[-1]
+
+        if thumbnail_info:
+            return FileInfo(
+                file_id=file_id,
+                url_cache=url_cache,
+                server_name=server_name,
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
+                thumbnail_length=thumbnail_info["thumbnail_length"],
+            )
+
+        # No matching thumbnail was found.
+        return None
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 32a8e4f960..07903e4017 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
 # limitations under the License.
 import logging
 from io import BytesIO
+from typing import Tuple
 
 from PIL import Image
 
@@ -39,7 +41,7 @@ class Thumbnailer:
 
     FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
 
-    def __init__(self, input_path):
+    def __init__(self, input_path: str):
         try:
             self.image = Image.open(input_path)
         except OSError as e:
@@ -59,11 +61,11 @@ class Thumbnailer:
             # A lot of parsing errors can happen when parsing EXIF
             logger.info("Error parsing image EXIF information: %s", e)
 
-    def transpose(self):
+    def transpose(self) -> Tuple[int, int]:
         """Transpose the image using its EXIF Orientation tag
 
         Returns:
-            Tuple[int, int]: (width, height) containing the new image size in pixels.
+            A tuple containing the new image size in pixels as (width, height).
         """
         if self.transpose_method is not None:
             self.image = self.image.transpose(self.transpose_method)
@@ -73,7 +75,7 @@ class Thumbnailer:
             self.image.info["exif"] = None
         return self.image.size
 
-    def aspect(self, max_width, max_height):
+    def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
         """Calculate the largest size that preserves aspect ratio which
         fits within the given rectangle::
 
@@ -91,7 +93,7 @@ class Thumbnailer:
         else:
             return (max_height * self.width) // self.height, max_height
 
-    def _resize(self, width, height):
+    def _resize(self, width: int, height: int) -> Image:
         # 1-bit or 8-bit color palette images need converting to RGB
         # otherwise they will be scaled using nearest neighbour which
         # looks awful
@@ -99,7 +101,7 @@ class Thumbnailer:
             self.image = self.image.convert("RGB")
         return self.image.resize((width, height), Image.ANTIALIAS)
 
-    def scale(self, width, height, output_type):
+    def scale(self, width: int, height: int, output_type: str) -> BytesIO:
         """Rescales the image to the given dimensions.
 
         Returns:
@@ -108,7 +110,7 @@ class Thumbnailer:
         scaled = self._resize(width, height)
         return self._encode_image(scaled, output_type)
 
-    def crop(self, width, height, output_type):
+    def crop(self, width: int, height: int, output_type: str) -> BytesIO:
         """Rescales and crops the image to the given dimensions preserving
         aspect::
             (w_in / h_in) = (w_scaled / h_scaled)
@@ -136,7 +138,7 @@ class Thumbnailer:
             cropped = scaled_image.crop((crop_left, 0, crop_right, height))
         return self._encode_image(cropped, output_type)
 
-    def _encode_image(self, output_image, output_type):
+    def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
         output_bytes_io = BytesIO()
         fmt = self.FORMATS[output_type]
         if fmt == "JPEG":
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index d76f7389e1..6da76ae994 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,18 +15,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.rest.media.v1.media_repository import MediaRepository
+
 logger = logging.getLogger(__name__)
 
 
 class UploadResource(DirectServeJsonResource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
 
         self.media_repo = media_repo
@@ -37,14 +45,14 @@ class UploadResource(DirectServeJsonResource):
         self.max_upload_size = hs.config.max_upload_size
         self.clock = hs.get_clock()
 
-    async def _async_render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request: Request) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
 
-    async def _async_render_POST(self, request):
+    async def _async_render_POST(self, request: Request) -> None:
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
-        content_length = request.getHeader(b"Content-Length").decode("ascii")
+        content_length = request.getHeader("Content-Length")
         if content_length is None:
             raise SynapseError(msg="Request must specify a Content-Length", code=400)
         if int(content_length) > self.max_upload_size:
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index c0b733488b..e5ef515090 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,3 +12,55 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from typing import TYPE_CHECKING, Mapping
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client.sso_register import SsoRegisterResource
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]:
+    """Builds a resource tree to include synapse-specific client resources
+
+    These are resources which should be loaded on all workers which expose a C-S API:
+    ie, the main process, and any generic workers so configured.
+
+    Returns:
+         map from path to Resource.
+    """
+    resources = {
+        # SSO bits. These are always loaded, whether or not SSO login is actually
+        # enabled (they just won't work very well if it's not)
+        "/_synapse/client/pick_idp": PickIdpResource(hs),
+        "/_synapse/client/pick_username": pick_username_resource(hs),
+        "/_synapse/client/new_user_consent": NewUserConsentResource(hs),
+        "/_synapse/client/sso_register": SsoRegisterResource(hs),
+    }
+
+    # provider-specific SSO bits. Only load these if they are enabled, since they
+    # rely on optional dependencies.
+    if hs.config.oidc_enabled:
+        from synapse.rest.synapse.client.oidc import OIDCResource
+
+        resources["/_synapse/client/oidc"] = OIDCResource(hs)
+
+    if hs.config.saml2_enabled:
+        from synapse.rest.synapse.client.saml2 import SAML2Resource
+
+        res = SAML2Resource(hs)
+        resources["/_synapse/client/saml2"] = res
+
+        # This is also mounted under '/_matrix' for backwards-compatibility.
+        resources["/_matrix/saml2"] = res
+
+    return resources
+
+
+__all__ = ["build_synapse_client_resource_tree"]
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
new file mode 100644
index 0000000000..b2e0f93810
--- /dev/null
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
+from synapse.http.servlet import parse_string
+from synapse.types import UserID
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class NewUserConsentResource(DirectServeHtmlResource):
+    """A resource which collects consent to the server's terms from a new user
+
+    This resource gets mounted at /_synapse/client/new_user_consent, and is shown
+    when we are automatically creating a new user due to an SSO login.
+
+    It shows a template which prompts the user to go and read the Ts and Cs, and click
+    a clickybox if they have done so.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+        self._server_name = hs.hostname
+        self._consent_version = hs.config.consent.user_consent_version
+
+        def template_search_dirs():
+            if hs.config.sso.sso_template_dir:
+                yield hs.config.sso.sso_template_dir
+            yield hs.config.sso.default_template_dir
+
+        self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+            session = self._sso_handler.get_mapping_session(session_id)
+        except SynapseError as e:
+            logger.warning("Error fetching session: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        user_id = UserID(session.chosen_localpart, self._server_name)
+        user_profile = {
+            "display_name": session.display_name,
+        }
+
+        template_params = {
+            "user_id": user_id.to_string(),
+            "user_profile": user_profile,
+            "consent_version": self._consent_version,
+            "terms_url": "/_matrix/consent?v=%s" % (self._consent_version,),
+        }
+
+        template = self._jinja_env.get_template("sso_new_user_consent.html")
+        html = template.render(template_params)
+        respond_with_html(request, 200, html)
+
+    async def _async_render_POST(self, request: Request):
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        try:
+            accepted_version = parse_string(request, "accepted_version", required=True)
+        except SynapseError as e:
+            self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+            return
+
+        await self._sso_handler.handle_terms_accepted(
+            request, session_id, accepted_version
+        )
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index d958dd65bb..64c0deb75d 100644
--- a/synapse/rest/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -12,11 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 
 from twisted.web.resource import Resource
 
-from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
 
 logger = logging.getLogger(__name__)
 
@@ -25,3 +26,6 @@ class OIDCResource(Resource):
     def __init__(self, hs):
         Resource.__init__(self)
         self.putChild(b"callback", OIDCCallbackResource(hs))
+
+
+__all__ = ["OIDCResource"]
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index f7a0bc4bdb..f7a0bc4bdb 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
new file mode 100644
index 0000000000..9550b82998
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    finish_request,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PickIdpResource(DirectServeHtmlResource):
+    """IdP picker resource.
+
+    This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page
+    which prompts the user to choose an Identity Provider from the list.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+        self._sso_login_idp_picker_template = (
+            hs.config.sso.sso_login_idp_picker_template
+        )
+        self._server_name = hs.hostname
+
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding="utf-8"
+        )
+        idp = parse_string(request, "idp", required=False)
+
+        # if we need to pick an IdP, do so
+        if not idp:
+            return await self._serve_id_picker(request, client_redirect_url)
+
+        # otherwise, redirect to the IdP's redirect URI
+        providers = self._sso_handler.get_identity_providers()
+        auth_provider = providers.get(idp)
+        if not auth_provider:
+            logger.info("Unknown idp %r", idp)
+            self._sso_handler.render_error(
+                request, "unknown_idp", "Unknown identity provider ID"
+            )
+            return
+
+        sso_url = await auth_provider.handle_redirect_request(
+            request, client_redirect_url.encode("utf8")
+        )
+        logger.info("Redirecting to %s", sso_url)
+        request.redirect(sso_url)
+        finish_request(request)
+
+    async def _serve_id_picker(
+        self, request: SynapseRequest, client_redirect_url: str
+    ) -> None:
+        # otherwise, serve up the IdP picker
+        providers = self._sso_handler.get_identity_providers()
+        html = self._sso_login_idp_picker_template.render(
+            redirect_url=client_redirect_url,
+            server_name=self._server_name,
+            providers=providers.values(),
+        )
+        respond_with_html(request, 200, html)
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
new file mode 100644
index 0000000000..96077cfcd1
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, List
+
+from twisted.web.http import Request
+from twisted.web.resource import Resource
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    DirectServeJsonResource,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_boolean, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+def pick_username_resource(hs: "HomeServer") -> Resource:
+    """Factory method to generate the username picker resource.
+
+    This resource gets mounted under /_synapse/client/pick_username and has two
+       children:
+
+      * "account_details": renders the form and handles the POSTed response
+      * "check": a JSON endpoint which checks if a userid is free.
+    """
+
+    res = Resource()
+    res.putChild(b"account_details", AccountDetailsResource(hs))
+    res.putChild(b"check", AvailabilityCheckResource(hs))
+
+    return res
+
+
+class AvailabilityCheckResource(DirectServeJsonResource):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+    async def _async_render_GET(self, request: Request):
+        localpart = parse_string(request, "username", required=True)
+
+        session_id = get_username_mapping_session_cookie_from_request(request)
+
+        is_available = await self._sso_handler.check_username_availability(
+            localpart, session_id
+        )
+        return 200, {"available": is_available}
+
+
+class AccountDetailsResource(DirectServeHtmlResource):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+        def template_search_dirs():
+            if hs.config.sso.sso_template_dir:
+                yield hs.config.sso.sso_template_dir
+            yield hs.config.sso.default_template_dir
+
+        self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+            session = self._sso_handler.get_mapping_session(session_id)
+        except SynapseError as e:
+            logger.warning("Error fetching session: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        idp_id = session.auth_provider_id
+        template_params = {
+            "idp": self._sso_handler.get_identity_providers()[idp_id],
+            "user_attributes": {
+                "display_name": session.display_name,
+                "emails": session.emails,
+            },
+        }
+
+        template = self._jinja_env.get_template("sso_auth_account_details.html")
+        html = template.render(template_params)
+        respond_with_html(request, 200, html)
+
+    async def _async_render_POST(self, request: SynapseRequest):
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+
+        try:
+            localpart = parse_string(request, "username", required=True)
+            use_display_name = parse_boolean(request, "use_display_name", default=False)
+
+            try:
+                emails_to_use = [
+                    val.decode("utf-8") for val in request.args.get(b"use_email", [])
+                ]  # type: List[str]
+            except ValueError:
+                raise SynapseError(400, "Query parameter use_email must be utf-8")
+        except SynapseError as e:
+            logger.warning("[session %s] bad param: %s", session_id, e)
+            self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+            return
+
+        await self._sso_handler.handle_submit_username_request(
+            request, session_id, localpart, use_display_name, emails_to_use
+        )
diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 68da37ca6a..3e8235ee1e 100644
--- a/synapse/rest/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -12,12 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 
 from twisted.web.resource import Resource
 
-from synapse.rest.saml2.metadata_resource import SAML2MetadataResource
-from synapse.rest.saml2.response_resource import SAML2ResponseResource
+from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
+from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
 
 logger = logging.getLogger(__name__)
 
@@ -27,3 +28,6 @@ class SAML2Resource(Resource):
         Resource.__init__(self)
         self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
         self.putChild(b"authn_response", SAML2ResponseResource(hs))
+
+
+__all__ = ["SAML2Resource"]
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index 1e8526e22e..1e8526e22e 100644
--- a/synapse/rest/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index f6668fb5e3..f6668fb5e3 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py
new file mode 100644
index 0000000000..dfefeb7796
--- /dev/null
+++ b/synapse/rest/synapse/client/sso_register.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SsoRegisterResource(DirectServeHtmlResource):
+    """A resource which completes SSO registration
+
+    This resource gets mounted at /_synapse/client/sso_register, and is shown
+    after we collect username and/or consent for a new SSO user. It (finally) registers
+    the user, and confirms redirect to the client
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+
+    async def _async_render_GET(self, request: Request) -> None:
+        try:
+            session_id = get_username_mapping_session_cookie_from_request(request)
+        except SynapseError as e:
+            logger.warning("Error fetching session cookie: %s", e)
+            self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+            return
+        await self._sso_handler.register_sso_user(request, session_id)
diff --git a/synapse/server.py b/synapse/server.py
index b017e3489f..9bdd3177d7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender
 from synapse.federation.transport.client import TransportLayerClient
 from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
 from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
+from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.acme import AcmeHandler
 from synapse.handlers.admin import AdminHandler
@@ -102,6 +103,7 @@ from synapse.notifier import Notifier
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
 from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.external_cache import ExternalCache
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.resource import ReplicationStreamer
 from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -127,6 +129,8 @@ from synapse.util.stringutils import random_string
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from txredisapi import RedisProtocol
+
     from synapse.handlers.oidc_handler import OidcHandler
     from synapse.handlers.saml_handler import SamlHandler
 
@@ -283,10 +287,6 @@ class HomeServer(metaclass=abc.ABCMeta):
         """
         return self._reactor
 
-    def get_ip_from_request(self, request) -> str:
-        # X-Forwarded-For is handled by our custom request type.
-        return request.getClientIP()
-
     def is_mine(self, domain_specific_string: DomainSpecificString) -> bool:
         return domain_specific_string.domain == self.hostname
 
@@ -350,17 +350,47 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_simple_http_client(self) -> SimpleHttpClient:
+        """
+        An HTTP client with no special configuration.
+        """
         return SimpleHttpClient(self)
 
     @cache_in_self
     def get_proxied_http_client(self) -> SimpleHttpClient:
+        """
+        An HTTP client that uses configured HTTP(S) proxies.
+        """
+        return SimpleHttpClient(
+            self,
+            http_proxy=os.getenvb(b"http_proxy"),
+            https_proxy=os.getenvb(b"HTTPS_PROXY"),
+        )
+
+    @cache_in_self
+    def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
+        """
+        An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
+        based on the IP range blacklist/whitelist.
+        """
         return SimpleHttpClient(
             self,
+            ip_whitelist=self.config.ip_range_whitelist,
+            ip_blacklist=self.config.ip_range_blacklist,
             http_proxy=os.getenvb(b"http_proxy"),
             https_proxy=os.getenvb(b"HTTPS_PROXY"),
         )
 
     @cache_in_self
+    def get_federation_http_client(self) -> MatrixFederationHttpClient:
+        """
+        An HTTP client for federation.
+        """
+        tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
+            self.config
+        )
+        return MatrixFederationHttpClient(self, tls_client_options_factory)
+
+    @cache_in_self
     def get_room_creation_handler(self) -> RoomCreationHandler:
         return RoomCreationHandler(self)
 
@@ -475,7 +505,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return InitialSyncHandler(self)
 
     @cache_in_self
-    def get_profile_handler(self):
+    def get_profile_handler(self) -> ProfileHandler:
         return ProfileHandler(self)
 
     @cache_in_self
@@ -515,13 +545,6 @@ class HomeServer(metaclass=abc.ABCMeta):
         return PusherPool(self)
 
     @cache_in_self
-    def get_http_client(self) -> MatrixFederationHttpClient:
-        tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
-            self.config
-        )
-        return MatrixFederationHttpClient(self, tls_client_options_factory)
-
-    @cache_in_self
     def get_media_repository_resource(self) -> MediaRepositoryResource:
         # build the media repo resource. This indirects through the HomeServer
         # to ensure that we only have a single instance of
@@ -595,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return StatsHandler(self)
 
     @cache_in_self
-    def get_spam_checker(self):
+    def get_spam_checker(self) -> SpamChecker:
         return SpamChecker(self)
 
     @cache_in_self
@@ -692,6 +715,37 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_module_api(self) -> ModuleApi:
         return ModuleApi(self, self.get_auth_handler())
 
+    @cache_in_self
+    def get_account_data_handler(self) -> AccountDataHandler:
+        return AccountDataHandler(self)
+
+    @cache_in_self
+    def get_external_cache(self) -> ExternalCache:
+        return ExternalCache(self)
+
+    @cache_in_self
+    def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
+        if not self.config.redis.redis_enabled:
+            return None
+
+        # We only want to import redis module if we're using it, as we have
+        # `txredisapi` as an optional dependency.
+        from synapse.replication.tcp.redis import lazyConnection
+
+        logger.info(
+            "Connecting to redis (host=%r port=%r) for external cache",
+            self.config.redis_host,
+            self.config.redis_port,
+        )
+
+        return lazyConnection(
+            hs=self,
+            host=self.config.redis_host,
+            port=self.config.redis_port,
+            password=self.config.redis.redis_password,
+            reconnect=True,
+        )
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 2258d306d9..8dd01fce76 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -42,6 +42,7 @@ class ResourceLimitsServerNotices:
         self._auth = hs.get_auth()
         self._config = hs.config
         self._resouce_limited = False
+        self._account_data_handler = hs.get_account_data_handler()
         self._message_handler = hs.get_message_handler()
         self._state = hs.get_state_handler()
 
@@ -177,7 +178,7 @@ class ResourceLimitsServerNotices:
                 # tag already present, nothing to do here
                 need_to_set_tag = False
         if need_to_set_tag:
-            max_id = await self._store.add_tag_to_room(
+            max_id = await self._account_data_handler.add_tag_to_room(
                 user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
             )
             self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 100dbd5e2c..c46b2f047d 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -35,6 +35,7 @@ class ServerNoticesManager:
 
         self._store = hs.get_datastore()
         self._config = hs.config
+        self._account_data_handler = hs.get_account_data_handler()
         self._room_creation_handler = hs.get_room_creation_handler()
         self._room_member_handler = hs.get_room_member_handler()
         self._event_creation_handler = hs.get_event_creation_handler()
@@ -163,7 +164,7 @@ class ServerNoticesManager:
         )
         room_id = info["room_id"]
 
-        max_id = await self._store.add_tag_to_room(
+        max_id = await self._account_data_handler.add_tag_to_room(
             user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
         )
         self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1fa3b280b4..3bd9ff8ca0 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler:
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
+            entry = None
 
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
                 current_state_ids=state_ids_before_event,
             )
 
-            # XXX: can we update the state cache entry for the new state group? or
-            # could we set a flag on resolve_state_groups_for_events to tell it to
-            # always make a state group?
+            # Assign the new state group to the cached state entry.
+            #
+            # Note that this can race in that we could generate multiple state
+            # groups for the same state entry, but that is just inefficient
+            # rather than dangerous.
+            if entry and entry.state_group is None:
+                entry.state_group = state_group_before_event
 
         #
         # now if it's not a state event, we're done
@@ -783,7 +788,7 @@ class StateResolutionStore:
         )
 
     def get_auth_chain_difference(
-        self, state_sets: List[Set[str]]
+        self, room_id: str, state_sets: List[Set[str]]
     ) -> Awaitable[Set[str]]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
@@ -796,4 +801,4 @@ class StateResolutionStore:
             An awaitable that resolves to a set of event IDs.
         """
 
-        return self.store.get_auth_chain_difference(state_sets)
+        return self.store.get_auth_chain_difference(room_id, state_sets)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d728..e585954bd8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
 
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
-    auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
+    auth_diff = await _get_auth_chain_difference(
+        room_id, state_sets, event_map, state_res_store
+    )
 
     full_conflicted_set = set(
         itertools.chain(
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
 
 
 async def _get_auth_chain_difference(
+    room_id: str,
     state_sets: Sequence[StateMap[str]],
     event_map: Dict[str, EventBase],
     state_res_store: "synapse.state.StateResolutionStore",
@@ -252,9 +255,90 @@ async def _get_auth_chain_difference(
         Set of event IDs
     """
 
+    # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+    # all events passed to it (and their auth chains) have been persisted
+    # previously. This is not the case for any events in the `event_map`, and so
+    # we need to manually handle those events.
+    #
+    # We do this by:
+    #   1. calculating the auth chain difference for the state sets based on the
+    #      events in `event_map` alone
+    #   2. replacing any events in the state_sets that are also in `event_map`
+    #      with their auth events (recursively), and then calling
+    #      `store.get_auth_chain_difference` as normal
+    #   3. adding the results of 1 and 2 together.
+
+    # Map from event ID in `event_map` to their auth event IDs, and their auth
+    # event IDs if they appear in the `event_map`. This is the intersection of
+    # the event's auth chain with the events in the `event_map` *plus* their
+    # auth event IDs.
+    events_to_auth_chain = {}  # type: Dict[str, Set[str]]
+    for event in event_map.values():
+        chain = {event.event_id}
+        events_to_auth_chain[event.event_id] = chain
+
+        to_search = [event]
+        while to_search:
+            for auth_id in to_search.pop().auth_event_ids():
+                chain.add(auth_id)
+                auth_event = event_map.get(auth_id)
+                if auth_event:
+                    to_search.append(auth_event)
+
+    # We now a) calculate the auth chain difference for the unpersisted events
+    # and b) work out the state sets to pass to the store.
+    #
+    # Note: If the `event_map` is empty (which is the common case), we can do a
+    # much simpler calculation.
+    if event_map:
+        # The list of state sets to pass to the store, where each state set is a set
+        # of the event ids making up the state. This is similar to `state_sets`,
+        # except that (a) we only have event ids, not the complete
+        # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+        # unpersisted events and replaced them with the persisted events in
+        # their auth chain.
+        state_sets_ids = []  # type: List[Set[str]]
+
+        # For each state set, the unpersisted event IDs reachable (by their auth
+        # chain) from the events in that set.
+        unpersisted_set_ids = []  # type: List[Set[str]]
+
+        for state_set in state_sets:
+            set_ids = set()  # type: Set[str]
+            state_sets_ids.append(set_ids)
+
+            unpersisted_ids = set()  # type: Set[str]
+            unpersisted_set_ids.append(unpersisted_ids)
+
+            for event_id in state_set.values():
+                event_chain = events_to_auth_chain.get(event_id)
+                if event_chain is not None:
+                    # We have an event in `event_map`. We add all the auth
+                    # events that it references (that aren't also in `event_map`).
+                    set_ids.update(e for e in event_chain if e not in event_map)
+
+                    # We also add the full chain of unpersisted event IDs
+                    # referenced by this state set, so that we can work out the
+                    # auth chain difference of the unpersisted events.
+                    unpersisted_ids.update(e for e in event_chain if e in event_map)
+                else:
+                    set_ids.add(event_id)
+
+        # The auth chain difference of the unpersisted events of the state sets
+        # is calculated by taking the difference between the union and
+        # intersections.
+        union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+        intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+        difference_from_event_map = union - intersection  # type: Collection[str]
+    else:
+        difference_from_event_map = ()
+        state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
     difference = await state_res_store.get_auth_chain_difference(
-        [set(state_set.values()) for state_set in state_sets]
+        room_id, state_sets_ids
     )
+    difference.update(difference_from_event_map)
 
     return difference
 
@@ -574,7 +658,7 @@ async def _get_mainline_depth_for_event(
     # We do an iterative search, replacing `event with the power level in its
     # auth events (if any)
     while tmp_event:
-        depth = mainline_map.get(event.event_id)
+        depth = mainline_map.get(tmp_event.event_id)
         if depth is not None:
             return depth
 
diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css
index 83e4f6abc8..dd76714a92 100644
--- a/synapse/static/client/login/style.css
+++ b/synapse/static/client/login/style.css
@@ -31,6 +31,11 @@ form {
     margin: 10px 0 0 0;
 }
 
+ul.radiobuttons {
+    text-align: left;
+    list-style: none;
+}
+
 /*
  * Add some padding to the viewport.
  */
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index bbff3c8d5b..c0d9d1240f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the
 data stores associated with them (e.g. the schema version tables), which are
 stored in `synapse.storage.schema`.
 """
+from typing import TYPE_CHECKING
 
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main import DataStore
@@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage
 from synapse.storage.purge_events import PurgeEventsStorage
 from synapse.storage.state import StateGroupStorage
 
-__all__ = ["DataStores", "DataStore"]
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
+__all__ = ["Databases", "DataStore"]
 
 
 class Storage:
     """The high level interfaces for talking to various storage layers.
     """
 
-    def __init__(self, hs, stores: Databases):
+    def __init__(self, hs: "HomeServer", stores: Databases):
         # We include the main data store here mainly so that we don't have to
         # rewrite all the existing code to split it into high vs low level
         # interfaces.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2b196ded1b..a25c4093bc 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,14 +17,18 @@
 import logging
 import random
 from abc import ABCMeta
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
 
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
-from synapse.types import Collection, get_domain_from_id
+from synapse.storage.types import Connection
+from synapse.types import Collection, StreamToken, get_domain_from_id
 from synapse.util import json_decoder
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
         self.db_pool = database
         self.rand = random.SystemRandom()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: StreamToken,
+        rows: Iterable[Any],
+    ) -> None:
         pass
 
-    def _invalidate_state_caches(self, room_id, members_changed):
+    def _invalidate_state_caches(
+        self, room_id: str, members_changed: Iterable[str]
+    ) -> None:
         """Invalidates caches that are based on the current state, but does
         not stream invalidations down replication.
 
         Args:
-            room_id (str): Room where state changed
-            members_changed (iterable[str]): The user_ids of members that have
-                changed
+            room_id: Room where state changed
+            members_changed: The user_ids of members that have changed
         """
         for host in {get_domain_from_id(u) for u in members_changed}:
             self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
@@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
-    ):
+    ) -> None:
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
@@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
             cache.invalidate(tuple(key))
 
 
-def db_to_json(db_content):
+def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
     """
     Take some data from a database row and return a JSON-decoded object.
 
     Args:
-        db_content (memoryview|buffer|bytes|bytearray|unicode)
+        db_content: The JSON-encoded contents from the database.
+
+    Returns:
+        The object decoded from JSON.
     """
     # psycopg2 on Python 3 returns memoryview objects, which we need to
     # cast to bytes to decode
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 810721ebe9..29b8ca676a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,29 +12,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.types import Connection
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 
 from . import engines
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.storage.database import DatabasePool, LoggingTransaction
+
 logger = logging.getLogger(__name__)
 
 
 class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         self.name = name
         self.total_item_count = 0
-        self.total_duration_ms = 0
-        self.avg_item_count = 0
-        self.avg_duration_ms = 0
+        self.total_duration_ms = 0.0
+        self.avg_item_count = 0.0
+        self.avg_duration_ms = 0.0
 
-    def update(self, item_count, duration_ms):
+    def update(self, item_count: int, duration_ms: float) -> None:
         """Update the stats after doing an update"""
         self.total_item_count += item_count
         self.total_duration_ms += duration_ms
@@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
         self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
         self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
 
-    def average_items_per_ms(self):
+    def average_items_per_ms(self) -> Optional[float]:
         """An estimate of how long it takes to do a single update.
         Returns:
             A duration in ms as a float
@@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
             # changes in how long the update process takes.
             return float(self.avg_item_count) / float(self.avg_duration_ms)
 
-    def total_items_per_ms(self):
+    def total_items_per_ms(self) -> Optional[float]:
         """An estimate of how long it takes to do a single update.
         Returns:
             A duration in ms as a float
@@ -83,21 +88,25 @@ class BackgroundUpdater:
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
 
-    def __init__(self, hs, database):
+    def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database
 
         # if a background update is currently running, its name.
         self._current_background_update = None  # type: Optional[str]
 
-        self._background_update_performance = {}
-        self._background_update_handlers = {}
+        self._background_update_performance = (
+            {}
+        )  # type: Dict[str, BackgroundUpdatePerformance]
+        self._background_update_handlers = (
+            {}
+        )  # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
         self._all_done = False
 
-    def start_doing_background_updates(self):
+    def start_doing_background_updates(self) -> None:
         run_as_background_process("background_updates", self.run_background_updates)
 
-    async def run_background_updates(self, sleep=True):
+    async def run_background_updates(self, sleep: bool = True) -> None:
         logger.info("Starting background schema updates")
         while True:
             if sleep:
@@ -148,7 +157,7 @@ class BackgroundUpdater:
 
         return False
 
-    async def has_completed_background_update(self, update_name) -> bool:
+    async def has_completed_background_update(self, update_name: str) -> bool:
         """Check if the given background update has finished running.
         """
         if self._all_done:
@@ -173,8 +182,7 @@ class BackgroundUpdater:
         Returns once some amount of work is done.
 
         Args:
-            desired_duration_ms(float): How long we want to spend
-                updating.
+            desired_duration_ms: How long we want to spend updating.
         Returns:
             True if we have finished running all the background updates, otherwise False
         """
@@ -220,6 +228,7 @@ class BackgroundUpdater:
         return False
 
     async def _do_background_update(self, desired_duration_ms: float) -> int:
+        assert self._current_background_update is not None
         update_name = self._current_background_update
         logger.info("Starting update batch on background update '%s'", update_name)
 
@@ -273,7 +282,11 @@ class BackgroundUpdater:
 
         return len(self._background_update_performance)
 
-    def register_background_update_handler(self, update_name, update_handler):
+    def register_background_update_handler(
+        self,
+        update_name: str,
+        update_handler: Callable[[JsonDict, int], Awaitable[int]],
+    ):
         """Register a handler for doing a background update.
 
         The handler should take two arguments:
@@ -287,12 +300,12 @@ class BackgroundUpdater:
         The handler is responsible for updating the progress of the update.
 
         Args:
-            update_name(str): The name of the update that this code handles.
-            update_handler(function): The function that does the update.
+            update_name: The name of the update that this code handles.
+            update_handler: The function that does the update.
         """
         self._background_update_handlers[update_name] = update_handler
 
-    def register_noop_background_update(self, update_name):
+    def register_noop_background_update(self, update_name: str) -> None:
         """Register a noop handler for a background update.
 
         This is useful when we previously did a background update, but no
@@ -302,10 +315,10 @@ class BackgroundUpdater:
         also be called to clear the update.
 
         Args:
-            update_name (str): Name of update
+            update_name: Name of update
         """
 
-        async def noop_update(progress, batch_size):
+        async def noop_update(progress: JsonDict, batch_size: int) -> int:
             await self._end_background_update(update_name)
             return 1
 
@@ -313,14 +326,14 @@ class BackgroundUpdater:
 
     def register_background_index_update(
         self,
-        update_name,
-        index_name,
-        table,
-        columns,
-        where_clause=None,
-        unique=False,
-        psql_only=False,
-    ):
+        update_name: str,
+        index_name: str,
+        table: str,
+        columns: Iterable[str],
+        where_clause: Optional[str] = None,
+        unique: bool = False,
+        psql_only: bool = False,
+    ) -> None:
         """Helper for store classes to do a background index addition
 
         To use:
@@ -332,19 +345,19 @@ class BackgroundUpdater:
         2. In the Store constructor, call this method
 
         Args:
-            update_name (str): update_name to register for
-            index_name (str): name of index to add
-            table (str): table to add index to
-            columns (list[str]): columns/expressions to include in index
-            unique (bool): true to make a UNIQUE index
+            update_name: update_name to register for
+            index_name: name of index to add
+            table: table to add index to
+            columns: columns/expressions to include in index
+            unique: true to make a UNIQUE index
             psql_only: true to only create this index on psql databases (useful
                 for virtual sqlite tables)
         """
 
-        def create_index_psql(conn):
+        def create_index_psql(conn: Connection) -> None:
             conn.rollback()
             # postgres insists on autocommit for the index
-            conn.set_session(autocommit=True)
+            conn.set_session(autocommit=True)  # type: ignore
 
             try:
                 c = conn.cursor()
@@ -371,9 +384,9 @@ class BackgroundUpdater:
                 logger.debug("[SQL] %s", sql)
                 c.execute(sql)
             finally:
-                conn.set_session(autocommit=False)
+                conn.set_session(autocommit=False)  # type: ignore
 
-        def create_index_sqlite(conn):
+        def create_index_sqlite(conn: Connection) -> None:
             # Sqlite doesn't support concurrent creation of indexes.
             #
             # We don't use partial indices on SQLite as it wasn't introduced
@@ -399,7 +412,7 @@ class BackgroundUpdater:
             c.execute(sql)
 
         if isinstance(self.db_pool.engine, engines.PostgresEngine):
-            runner = create_index_psql
+            runner = create_index_psql  # type: Optional[Callable[[Connection], None]]
         elif psql_only:
             runner = None
         else:
@@ -433,7 +446,9 @@ class BackgroundUpdater:
             "background_updates", keyvalues={"update_name": update_name}
         )
 
-    async def _background_update_progress(self, update_name: str, progress: dict):
+    async def _background_update_progress(
+        self, update_name: str, progress: dict
+    ) -> None:
         """Update the progress of a background update
 
         Args:
@@ -441,20 +456,22 @@ class BackgroundUpdater:
             progress: The progress of the update.
         """
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "background_update_progress",
             self._background_update_progress_txn,
             update_name,
             progress,
         )
 
-    def _background_update_progress_txn(self, txn, update_name, progress):
+    def _background_update_progress_txn(
+        self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
+    ) -> None:
         """Update the progress of a background update
 
         Args:
-            txn(cursor): The transaction.
-            update_name(str): The name of the background update task
-            progress(dict): The progress of the update.
+            txn: The transaction.
+            update_name: The name of the background update task
+            progress: The progress of the update.
         """
 
         progress_json = json_encoder.encode(progress)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d1b5760c2c..d2ba4bd2fc 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -42,7 +42,6 @@ from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import (
     LoggingContext,
-    LoggingContextOrSentinel,
     current_context,
     make_deferred_yieldable,
 )
@@ -50,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection
 
 # python 3 does not have a maximum int value
@@ -180,6 +180,9 @@ class LoggingDatabaseConnection:
 _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
 
 
+R = TypeVar("R")
+
+
 class LoggingTransaction:
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
@@ -259,13 +262,32 @@ class LoggingTransaction:
         return self.txn.description
 
     def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+        """Similar to `executemany`, except `txn.rowcount` will not be correct
+        afterwards.
+
+        More efficient than `executemany` on PostgreSQL
+        """
+
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
-            for val in args:
-                self.execute(sql, val)
+            self.executemany(sql, args)
+
+    def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+        """Corresponds to psycopg2.extras.execute_values. Only available when
+        using postgres.
+
+        Always sets fetch=True when caling `execute_values`, so will return the
+        results.
+        """
+        assert isinstance(self.database_engine, PostgresEngine)
+        from psycopg2.extras import execute_values  # type: ignore
+
+        return self._do_execute(
+            lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+        )
 
     def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
@@ -277,7 +299,7 @@ class LoggingTransaction:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql: str, *args: Any) -> None:
+    def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -348,9 +370,6 @@ class PerformanceCounters:
         return top_n_counters
 
 
-R = TypeVar("R")
-
-
 class DatabasePool:
     """Wraps a single physical database and connection pool.
 
@@ -399,6 +418,16 @@ class DatabasePool:
                 self._check_safe_to_upsert,
             )
 
+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            engine, get_chain_id_txn, "event_auth_chain_id"
+        )
+
     def is_running(self) -> bool:
         """Is the database pool currently running
         """
@@ -671,12 +700,15 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
-        if not parent_context:
+        curr_context = current_context()
+        if not curr_context:
             logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
             )
             parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
 
         start_time = monotonic_time()
 
@@ -861,7 +893,7 @@ class DatabasePool:
             ", ".join("?" for _ in keys[0]),
         )
 
-        txn.executemany(sql, vals)
+        txn.execute_batch(sql, vals)
 
     async def simple_upsert(
         self,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 43660ec4fb..5d0845588c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
 from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
 from .filtering import FilteringStore
 from .group_server import GroupServerStore
 from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
     UIAuthStore,
     CacheInvalidationWorkerStore,
     ServerMetricsStore,
+    EventForwardExtremitiesStore,
 ):
     def __init__(self, database: DatabasePool, db_conn, hs):
         self.hs = hs
@@ -127,9 +129,6 @@ class DataStore(
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
-        self._device_inbox_id_gen = StreamIdGenerator(
-            db_conn, "device_inbox", "stream_id"
-        )
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
@@ -149,9 +148,6 @@ class DataStore(
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-        self._pushers_id_gen = StreamIdGenerator(
-            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
-        )
         self._group_updates_id_gen = StreamIdGenerator(
             db_conn, "local_group_updates", "stream_id"
         )
@@ -166,9 +162,13 @@ class DataStore(
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )
@@ -192,36 +192,6 @@ class DataStore(
             prefilled_cache=presence_cache_prefill,
         )
 
-        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
-        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_inbox",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            min_device_inbox_id,
-            prefilled_cache=device_inbox_prefill,
-        )
-        # The federation outbox and the local device inbox uses the same
-        # stream_id generator.
-        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_federation_outbox",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            min_device_outbox_id,
-            prefilled_cache=device_outbox_prefill,
-        )
-
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
             "DeviceListStreamChangeCache", device_list_max
@@ -342,12 +312,13 @@ class DataStore(
             filters = []
             args = [self.hs.config.server_name]
 
+            # `name` is in database already in lower case
             if name:
-                filters.append("(name LIKE ? OR displayname LIKE ?)")
-                args.extend(["@%" + name + "%:%", "%" + name + "%"])
+                filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
+                args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
             elif user_id:
                 filters.append("name LIKE ?")
-                args.extend(["%" + user_id + "%"])
+                args.extend(["%" + user_id.lower() + "%"])
 
             if not guests:
                 filters.append("is_guest = 0")
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 49ee23470d..a277a1ef13 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,30 +14,75 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from synapse.api.constants import AccountDataTypes
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import _CacheContext, cached
+from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
 
 
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
+class AccountDataWorkerStore(SQLBaseStore):
     """This is an abstract base class where subclasses must implement
     `get_max_account_data_stream_id` which can be called in the initializer.
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_account_data = (
+                self._instance_name in hs.config.worker.writers.account_data
+            )
+
+            self._account_data_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="account_data",
+                instance_name=self._instance_name,
+                tables=[
+                    ("room_account_data", "instance_name", "stream_id"),
+                    ("room_tags_revisions", "instance_name", "stream_id"),
+                    ("account_data", "instance_name", "stream_id"),
+                ],
+                sequence_name="account_data_sequence",
+                writers=hs.config.worker.writers.account_data,
+            )
+        else:
+            self._can_write_to_account_data = True
+
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                self._account_data_id_gen = StreamIdGenerator(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+            else:
+                self._account_data_id_gen = SlavedIdTracker(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
@@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
 
         super().__init__(database, db_conn, hs)
 
-    @abc.abstractmethod
-    def get_max_account_data_stream_id(self):
+    def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream ID for account data stream
 
         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._account_data_id_gen.get_current_token()
 
     @cached()
     async def get_account_data_for_user(
@@ -287,46 +331,46 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
-    @cached(num_args=2, cache_context=True, max_entries=5000)
-    async def is_ignored_by(
-        self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
-    ) -> bool:
-        ignored_account_data = await self.get_global_account_data_by_type_for_user(
-            AccountDataTypes.IGNORED_USER_LIST,
-            ignorer_user_id,
-            on_invalidate=cache_context.invalidate,
-        )
-        if not ignored_account_data:
-            return False
-
-        try:
-            return ignored_user_id in ignored_account_data.get("ignored_users", {})
-        except TypeError:
-            # The type of the ignored_users field is invalid.
-            return False
+    @cached(max_entries=5000, iterable=True)
+    async def ignored_by(self, user_id: str) -> Set[str]:
+        """
+        Get users which ignore the given user.
 
+        Params:
+            user_id: The user ID which might be ignored.
 
-class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn,
-            "account_data_max_stream_id",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
+        Return:
+            The user IDs which ignore the given user.
+        """
+        return set(
+            await self.db_pool.simple_select_onecol(
+                table="ignored_users",
+                keyvalues={"ignored_user_id": user_id},
+                retcol="ignorer_user_id",
+                desc="ignored_by",
+            )
         )
 
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self) -> int:
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            The maximum stream ID.
-        """
-        return self._account_data_id_gen.get_current_token()
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
+                self.get_account_data_for_room_and_type.invalidate(
+                    (row.user_id, row.room_id, row.data_type)
+                )
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -342,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
+        assert self._can_write_to_account_data
+
         content_json = json_encoder.encode(content)
 
         async with self._account_data_id_gen.get_next() as next_id:
@@ -360,14 +406,6 @@ class AccountDataStore(AccountDataWorkerStore):
                 lock=False,
             )
 
-            # it's theoretically possible for the above to succeed and the
-            # below to fail - in which case we might reuse a stream id on
-            # restart, and the above update might not get propagated. That
-            # doesn't sound any worse than the whole update getting lost,
-            # which is what would happen if we combined the two into one
-            # transaction.
-            await self._update_max_stream_id(next_id)
-
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
@@ -390,32 +428,18 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
-        content_json = json_encoder.encode(content)
+        assert self._can_write_to_account_data
 
         async with self._account_data_id_gen.get_next() as next_id:
-            # no need to lock here as account_data has a unique constraint on
-            # (user_id, account_data_type) so simple_upsert will retry if
-            # there is a conflict.
-            await self.db_pool.simple_upsert(
-                desc="add_user_account_data",
-                table="account_data",
-                keyvalues={"user_id": user_id, "account_data_type": account_data_type},
-                values={"stream_id": next_id, "content": content_json},
-                lock=False,
+            await self.db_pool.runInteraction(
+                "add_user_account_data",
+                self._add_account_data_for_user,
+                next_id,
+                user_id,
+                account_data_type,
+                content,
             )
 
-            # it's theoretically possible for the above to succeed and the
-            # below to fail - in which case we might reuse a stream id on
-            # restart, and the above update might not get propagated. That
-            # doesn't sound any worse than the whole update getting lost,
-            # which is what would happen if we combined the two into one
-            # transaction.
-            #
-            # Note: This is only here for backwards compat to allow admins to
-            # roll back to a previous Synapse version. Next time we update the
-            # database version we can remove this table.
-            await self._update_max_stream_id(next_id)
-
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
@@ -424,23 +448,71 @@ class AccountDataStore(AccountDataWorkerStore):
 
         return self._account_data_id_gen.get_current_token()
 
-    async def _update_max_stream_id(self, next_id: int) -> None:
-        """Update the max stream_id
+    def _add_account_data_for_user(
+        self,
+        txn,
+        next_id: int,
+        user_id: str,
+        account_data_type: str,
+        content: JsonDict,
+    ) -> None:
+        content_json = json_encoder.encode(content)
 
-        Args:
-            next_id: The the revision to advance to.
-        """
+        # no need to lock here as account_data has a unique constraint on
+        # (user_id, account_data_type) so simple_upsert will retry if
+        # there is a conflict.
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="account_data",
+            keyvalues={"user_id": user_id, "account_data_type": account_data_type},
+            values={"stream_id": next_id, "content": content_json},
+            lock=False,
+        )
 
-        # Note: This is only here for backwards compat to allow admins to
-        # roll back to a previous Synapse version. Next time we update the
-        # database version we can remove this table.
+        # Ignored users get denormalized into a separate table as an optimisation.
+        if account_data_type != AccountDataTypes.IGNORED_USER_LIST:
+            return
 
-        def _update(txn):
-            update_max_id_sql = (
-                "UPDATE account_data_max_stream_id"
-                " SET stream_id = ?"
-                " WHERE stream_id < ?"
+        # Insert / delete to sync the list of ignored users.
+        previously_ignored_users = set(
+            self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="ignored_users",
+                keyvalues={"ignorer_user_id": user_id},
+                retcol="ignored_user_id",
             )
-            txn.execute(update_max_id_sql, (next_id, next_id))
+        )
+
+        # If the data is invalid, no one is ignored.
+        ignored_users_content = content.get("ignored_users", {})
+        if isinstance(ignored_users_content, dict):
+            currently_ignored_users = set(ignored_users_content)
+        else:
+            currently_ignored_users = set()
+
+        # Delete entries which are no longer ignored.
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="ignored_users",
+            column="ignored_user_id",
+            iterable=previously_ignored_users - currently_ignored_users,
+            keyvalues={"ignorer_user_id": user_id},
+        )
 
-        await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
+        # Add entries which are newly ignored.
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="ignored_users",
+            values=[
+                {"ignorer_user_id": user_id, "ignored_user_id": u}
+                for u in currently_ignored_users - previously_ignored_users
+            ],
+        )
+
+        # Invalidate the cache for any ignored users which were added or removed.
+        for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
+            self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+
+
+class AccountDataStore(AccountDataWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 339bd691a4..ea1e8fb580 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,11 +14,12 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.types import UserID
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -406,6 +407,34 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             "_prune_old_user_ips", _prune_old_user_ips_txn
         )
 
+    async def get_last_client_ip_by_device(
+        self, user_id: str, device_id: Optional[str]
+    ) -> Dict[Tuple[str, str], dict]:
+        """For each device_id listed, give the user_ip it was last seen on.
+
+        The result might be slightly out of date as client IPs are inserted in batches.
+
+        Args:
+            user_id: The user to fetch devices for.
+            device_id: If None fetches all devices for the user
+
+        Returns:
+            A dictionary mapping a tuple of (user_id, device_id) to dicts, with
+            keys giving the column names from the devices table.
+        """
+
+        keyvalues = {"user_id": user_id}
+        if device_id is not None:
+            keyvalues["device_id"] = device_id
+
+        res = await self.db_pool.simple_select_list(
+            table="devices",
+            keyvalues=keyvalues,
+            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        )
+
+        return {(d["user_id"], d["device_id"]): d for d in res}
+
 
 class ClientIpStore(ClientIpWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -469,43 +498,35 @@ class ClientIpStore(ClientIpWorkerStore):
         for entry in to_update.items():
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
 
-            try:
-                self.db_pool.simple_upsert_txn(
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="user_ips",
+                keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
+                values={
+                    "user_agent": user_agent,
+                    "device_id": device_id,
+                    "last_seen": last_seen,
+                },
+                lock=False,
+            )
+
+            # Technically an access token might not be associated with
+            # a device so we need to check.
+            if device_id:
+                # this is always an update rather than an upsert: the row should
+                # already exist, and if it doesn't, that may be because it has been
+                # deleted, and we don't want to re-create it.
+                self.db_pool.simple_update_txn(
                     txn,
-                    table="user_ips",
-                    keyvalues={
-                        "user_id": user_id,
-                        "access_token": access_token,
-                        "ip": ip,
-                    },
-                    values={
+                    table="devices",
+                    keyvalues={"user_id": user_id, "device_id": device_id},
+                    updatevalues={
                         "user_agent": user_agent,
-                        "device_id": device_id,
                         "last_seen": last_seen,
+                        "ip": ip,
                     },
-                    lock=False,
                 )
 
-                # Technically an access token might not be associated with
-                # a device so we need to check.
-                if device_id:
-                    # this is always an update rather than an upsert: the row should
-                    # already exist, and if it doesn't, that may be because it has been
-                    # deleted, and we don't want to re-create it.
-                    self.db_pool.simple_update_txn(
-                        txn,
-                        table="devices",
-                        keyvalues={"user_id": user_id, "device_id": device_id},
-                        updatevalues={
-                            "user_agent": user_agent,
-                            "last_seen": last_seen,
-                            "ip": ip,
-                        },
-                    )
-            except Exception as e:
-                # Failed to upsert, log and continue
-                logger.error("Failed to insert client IP %r: %r", entry, e)
-
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
     ) -> Dict[Tuple[str, str], dict]:
@@ -519,18 +540,9 @@ class ClientIpStore(ClientIpWorkerStore):
             A dictionary mapping a tuple of (user_id, device_id) to dicts, with
             keys giving the column names from the devices table.
         """
+        ret = await super().get_last_client_ip_by_device(user_id, device_id)
 
-        keyvalues = {"user_id": user_id}
-        if device_id is not None:
-            keyvalues["device_id"] = device_id
-
-        res = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues=keyvalues,
-            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
-        )
-
-        ret = {(d["user_id"], d["device_id"]): d for d in res}
+        # Update what is retrieved from the database with data which is pending insertion.
         for key in self._batch_row_update:
             uid, access_token, ip = key
             if uid == user_id:
@@ -546,7 +558,9 @@ class ClientIpStore(ClientIpWorkerStore):
                     }
         return ret
 
-    async def get_user_ip_and_agents(self, user):
+    async def get_user_ip_and_agents(
+        self, user: UserID
+    ) -> List[Dict[str, Union[str, int]]]:
         user_id = user.to_string()
         results = {}
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index d42faa3f1f..31f70ac5ef 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -17,15 +17,98 @@ import logging
 from typing import List, Tuple
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.replication.tcp.streams import ToDeviceStream
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
 
 
 class DeviceInboxWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._instance_name = hs.get_instance_name()
+
+        # Map of (user_id, device_id) to the last stream_id that has been
+        # deleted up to. This is so that we can no op deletions.
+        self._last_device_delete_cache = ExpiringCache(
+            cache_name="last_device_delete_cache",
+            clock=self._clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+        )
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_device = (
+                self._instance_name in hs.config.worker.writers.to_device
+            )
+
+            self._device_inbox_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="to_device",
+                instance_name=self._instance_name,
+                tables=[("device_inbox", "instance_name", "stream_id")],
+                sequence_name="device_inbox_sequence",
+                writers=hs.config.worker.writers.to_device,
+            )
+        else:
+            self._can_write_to_device = True
+            self._device_inbox_id_gen = StreamIdGenerator(
+                db_conn, "device_inbox", "stream_id"
+            )
+
+        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_inbox",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_inbox_stream_cache = StreamChangeCache(
+            "DeviceInboxStreamChangeCache",
+            min_device_inbox_id,
+            prefilled_cache=device_inbox_prefill,
+        )
+
+        # The federation outbox and the local device inbox uses the same
+        # stream_id generator.
+        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_federation_outbox",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache",
+            min_device_outbox_id,
+            prefilled_cache=device_outbox_prefill,
+        )
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+            for row in rows:
+                if row.entity.startswith("@"):
+                    self._device_inbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+                else:
+                    self._device_federation_outbox_stream_cache.entity_has_changed(
+                        row.entity, token
+                    )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
@@ -278,52 +361,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )
 
-
-class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
-    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self.db_pool.updates.register_background_index_update(
-            "device_inbox_stream_index",
-            index_name="device_inbox_stream_id_user_id",
-            table="device_inbox",
-            columns=["stream_id", "user_id"],
-        )
-
-        self.db_pool.updates.register_background_update_handler(
-            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
-        )
-
-    async def _background_drop_index_device_inbox(self, progress, batch_size):
-        def reindex_txn(conn):
-            txn = conn.cursor()
-            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
-            txn.close()
-
-        await self.db_pool.runWithConnection(reindex_txn)
-
-        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
-        return 1
-
-
-class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
-    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        # Map of (user_id, device_id) to the last stream_id that has been
-        # deleted up to. This is so that we can no op deletions.
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
     @trace
     async def add_messages_to_device_inbox(
         self,
@@ -342,6 +379,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             The new stream_id.
         """
 
+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Add the local messages directly to the local inbox.
             self._add_messages_to_local_device_inbox_txn(
@@ -351,16 +390,20 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             # Add the remote messages to the federation outbox.
             # We'll send them to a remote server when we next send a
             # federation transaction to that destination.
-            sql = (
-                "INSERT INTO device_federation_outbox"
-                " (destination, stream_id, queued_ts, messages_json)"
-                " VALUES (?,?,?,?)"
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="device_federation_outbox",
+                values=[
+                    {
+                        "destination": destination,
+                        "stream_id": stream_id,
+                        "queued_ts": now_ms,
+                        "messages_json": json_encoder.encode(edu),
+                        "instance_name": self._instance_name,
+                    }
+                    for destination, edu in remote_messages_by_destination.items()
+                ],
             )
-            rows = []
-            for destination, edu in remote_messages_by_destination.items():
-                edu_json = json_encoder.encode(edu)
-                rows.append((destination, stream_id, now_ms, edu_json))
-            txn.executemany(sql, rows)
 
         async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
@@ -379,6 +422,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     async def add_messages_from_remote_to_device_inbox(
         self, origin: str, message_id: str, local_messages_by_user_then_device: dict
     ) -> int:
+        assert self._can_write_to_device
+
         def add_messages_txn(txn, now_ms, stream_id):
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
@@ -427,38 +472,45 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     def _add_messages_to_local_device_inbox_txn(
         self, txn, stream_id, messages_by_user_then_device
     ):
+        assert self._can_write_to_device
+
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
             messages_json_for_user = {}
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
-                sql = "SELECT device_id FROM devices WHERE user_id = ?"
-                txn.execute(sql, (user_id,))
+                devices = self.db_pool.simple_select_onecol_txn(
+                    txn,
+                    table="devices",
+                    keyvalues={"user_id": user_id},
+                    retcol="device_id",
+                )
+
                 message_json = json_encoder.encode(messages_by_device["*"])
-                for row in txn:
+                for device_id in devices:
                     # Add the message for all devices for this user on this
                     # server.
-                    device = row[0]
-                    messages_json_for_user[device] = message_json
+                    messages_json_for_user[device_id] = message_json
             else:
                 if not devices:
                     continue
 
-                clause, args = make_in_list_sql_clause(
-                    txn.database_engine, "device_id", devices
+                rows = self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="devices",
+                    keyvalues={"user_id": user_id},
+                    column="device_id",
+                    iterable=devices,
+                    retcols=("device_id",),
                 )
-                sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
 
-                # TODO: Maybe this needs to be done in batches if there are
-                # too many local devices for a given user.
-                txn.execute(sql, [user_id] + list(args))
-                for row in txn:
+                for row in rows:
                     # Only insert into the local inbox if the device exists on
                     # this server
-                    device = row[0]
-                    message_json = json_encoder.encode(messages_by_device[device])
-                    messages_json_for_user[device] = message_json
+                    device_id = row["device_id"]
+                    message_json = json_encoder.encode(messages_by_device[device_id])
+                    messages_json_for_user[device_id] = message_json
 
             if messages_json_for_user:
                 local_by_user_then_device[user_id] = messages_json_for_user
@@ -466,14 +518,52 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
         if not local_by_user_then_device:
             return
 
-        sql = (
-            "INSERT INTO device_inbox"
-            " (user_id, device_id, stream_id, message_json)"
-            " VALUES (?,?,?,?)"
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="device_inbox",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "stream_id": stream_id,
+                    "message_json": message_json,
+                    "instance_name": self._instance_name,
+                }
+                for user_id, messages_by_device in local_by_user_then_device.items()
+                for device_id, message_json in messages_by_device.items()
+            ],
         )
-        rows = []
-        for user_id, messages_by_device in local_by_user_then_device.items():
-            for device_id, message_json in messages_by_device.items():
-                rows.append((user_id, device_id, stream_id, message_json))
 
-        txn.executemany(sql, rows)
+
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_index_update(
+            "device_inbox_stream_index",
+            index_name="device_inbox_stream_id_user_id",
+            table="device_inbox",
+            columns=["stream_id", "user_id"],
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
+        )
+
+    async def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            txn = conn.cursor()
+            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+            txn.close()
+
+        await self.db_pool.runWithConnection(reindex_txn)
+
+        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+    pass
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index dfb4f87b8f..659d8f245f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -57,6 +57,38 @@ class DeviceWorkerStore(SQLBaseStore):
                 self._prune_old_outbound_device_pokes, 60 * 60 * 1000
             )
 
+    async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+        """Retrieve number of all devices of given users.
+        Only returns number of devices that are not marked as hidden.
+
+        Args:
+            user_ids: The IDs of the users which owns devices
+        Returns:
+            Number of devices of this users.
+        """
+
+        def count_devices_by_users_txn(txn, user_ids):
+            sql = """
+                SELECT count(*)
+                FROM devices
+                WHERE
+                    hidden = '0' AND
+            """
+
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "user_id", user_ids
+            )
+
+            txn.execute(sql + clause, args)
+            return txn.fetchone()[0]
+
+        if not user_ids:
+            return 0
+
+        return await self.db_pool.runInteraction(
+            "count_devices_by_users", count_devices_by_users_txn, user_ids
+        )
+
     async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
@@ -865,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 DELETE FROM device_lists_outbound_last_success
                 WHERE destination = ? AND user_id = ?
             """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+            txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
 
             logger.info("Pruned %d device list outbound pokes", count)
 
@@ -1311,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
-        txn.executemany(
+        txn.execute_batch(
             """
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4d1b92d1aa..309f1e865b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -513,21 +514,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         for user_chunk in batch_iter(user_ids, 100):
             clause, params = make_in_list_sql_clause(
-                txn.database_engine, "k.user_id", user_chunk
-            )
-            sql = (
-                """
-                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
-                  FROM e2e_cross_signing_keys k
-                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
-                                FROM e2e_cross_signing_keys
-                               GROUP BY user_id, keytype) s
-                 USING (user_id, stream_id, keytype)
-                 WHERE
-            """
-                + clause
+                txn.database_engine, "user_id", user_chunk
             )
 
+            # Fetch the latest key for each type per user.
+            if isinstance(self.database_engine, PostgresEngine):
+                # The `DISTINCT ON` clause will pick the *first* row it
+                # encounters, so ordering by stream ID desc will ensure we get
+                # the latest key.
+                sql = """
+                    SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
+                        FROM e2e_cross_signing_keys
+                        WHERE %(clause)s
+                        ORDER BY user_id, keytype, stream_id DESC
+                """ % {
+                    "clause": clause
+                }
+            else:
+                # SQLite has special handling for bare columns when using
+                # MIN/MAX with a `GROUP BY` clause where it picks the value from
+                # a row that matches the MIN/MAX.
+                sql = """
+                    SELECT user_id, keytype, keydata, MAX(stream_id)
+                        FROM e2e_cross_signing_keys
+                        WHERE %(clause)s
+                        GROUP BY user_id, keytype
+                """ % {
+                    "clause": clause
+                }
+
             txn.execute(sql, params)
             rows = self.db_pool.cursor_to_dict(txn)
 
@@ -619,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, dict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -707,53 +722,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         """Get the current stream id from the _device_list_id_gen"""
         ...
 
-
-class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
-    async def set_e2e_device_keys(
-        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
-    ) -> bool:
-        """Stores device keys for a device. Returns whether there was a change
-        or the keys were already in the database.
-        """
-
-        def _set_e2e_device_keys_txn(txn):
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
-
-            old_key_json = self.db_pool.simple_select_one_onecol_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                retcol="key_json",
-                allow_none=True,
-            )
-
-            # In py3 we need old_key_json to match new_key_json type. The DB
-            # returns unicode while encode_canonical_json returns bytes.
-            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
-            if old_key_json == new_key_json:
-                log_kv({"Message": "Device key already stored."})
-                return False
-
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                values={"ts_added_ms": time_now, "key_json": new_key_json},
-            )
-            log_kv({"message": "Device keys stored."})
-            return True
-
-        return await self.db_pool.runInteraction(
-            "set_e2e_device_keys", _set_e2e_device_keys_txn
-        )
-
     async def claim_e2e_one_time_keys(
         self, query_list: Iterable[Tuple[str, str, str]]
-    ) -> Dict[str, Dict[str, Dict[str, bytes]]]:
+    ) -> Dict[str, Dict[str, Dict[str, str]]]:
         """Take a list of one time keys out of the database.
 
         Args:
@@ -840,6 +811,50 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+    async def set_e2e_device_keys(
+        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
+    ) -> bool:
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
+        """
+
+        def _set_e2e_device_keys_txn(txn):
+            set_tag("user_id", user_id)
+            set_tag("device_id", device_id)
+            set_tag("time_now", time_now)
+            set_tag("device_keys", device_keys)
+
+            old_key_json = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                retcol="key_json",
+                allow_none=True,
+            )
+
+            # In py3 we need old_key_json to match new_key_json type. The DB
+            # returns unicode while encode_canonical_json returns bytes.
+            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+            if old_key_json == new_key_json:
+                log_kv({"Message": "Device key already stored."})
+                return False
+
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                values={"ts_added_ms": time_now, "key_json": new_key_json},
+            )
+            log_kv({"message": "Device keys stored."})
+            return True
+
+        return await self.db_pool.runInteraction(
+            "set_e2e_device_keys", _set_e2e_device_keys_txn
+        )
+
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn):
             log_kv(
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 2e07c37340..8326640d20 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Cursor
 from synapse.types import Collection
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
 logger = logging.getLogger(__name__)
 
 
+class _NoChainCoverIndex(Exception):
+    def __init__(self, room_id: str):
+        super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
+
+
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -137,7 +144,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return list(results)
 
-    async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
+    async def get_auth_chain_difference(
+        self, room_id: str, state_sets: List[Set[str]]
+    ) -> Set[str]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
 
@@ -149,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             The set of the difference in auth chains.
         """
 
+        # Check if we have indexed the room so we can use the chain cover
+        # algorithm.
+        room = await self.get_room(room_id)
+        if room["has_auth_chain_index"]:
+            try:
+                return await self.db_pool.runInteraction(
+                    "get_auth_chain_difference_chains",
+                    self._get_auth_chain_difference_using_cover_index_txn,
+                    room_id,
+                    state_sets,
+                )
+            except _NoChainCoverIndex:
+                # For whatever reason we don't actually have a chain cover index
+                # for the events in question, so we fall back to the old method.
+                pass
+
         return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
             self._get_auth_chain_difference_txn,
             state_sets,
         )
 
+    def _get_auth_chain_difference_using_cover_index_txn(
+        self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+    ) -> Set[str]:
+        """Calculates the auth chain difference using the chain index.
+
+        See docs/auth_chain_difference_algorithm.md for details
+        """
+
+        # First we look up the chain ID/sequence numbers for all the events, and
+        # work out the chain/sequence numbers reachable from each state set.
+
+        initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+        # Map from event_id -> (chain ID, seq no)
+        chain_info = {}  # type: Dict[str, Tuple[int, int]]
+
+        # Map from chain ID -> seq no -> event Id
+        chain_to_event = {}  # type: Dict[int, Dict[int, str]]
+
+        # All the chains that we've found that are reachable from the state
+        # sets.
+        seen_chains = set()  # type: Set[int]
+
+        sql = """
+            SELECT event_id, chain_id, sequence_number
+            FROM event_auth_chains
+            WHERE %s
+        """
+        for batch in batch_iter(initial_events, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for event_id, chain_id, sequence_number in txn:
+                chain_info[event_id] = (chain_id, sequence_number)
+                seen_chains.add(chain_id)
+                chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+        # Check that we actually have a chain ID for all the events.
+        events_missing_chain_info = initial_events.difference(chain_info)
+        if events_missing_chain_info:
+            # This can happen due to e.g. downgrade/upgrade of the server. We
+            # raise an exception and fall back to the previous algorithm.
+            logger.info(
+                "Unexpectedly found that events don't have chain IDs in room %s: %s",
+                room_id,
+                events_missing_chain_info,
+            )
+            raise _NoChainCoverIndex(room_id)
+
+        # Corresponds to `state_sets`, except as a map from chain ID to max
+        # sequence number reachable from the state set.
+        set_to_chain = []  # type: List[Dict[int, int]]
+        for state_set in state_sets:
+            chains = {}  # type: Dict[int, int]
+            set_to_chain.append(chains)
+
+            for event_id in state_set:
+                chain_id, seq_no = chain_info[event_id]
+
+                chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
+
+        # Now we look up all links for the chains we have, adding chains to
+        # set_to_chain that are reachable from each set.
+        sql = """
+            SELECT
+                origin_chain_id, origin_sequence_number,
+                target_chain_id, target_sequence_number
+            FROM event_auth_chain_links
+            WHERE %s
+        """
+
+        # (We need to take a copy of `seen_chains` as we want to mutate it in
+        # the loop)
+        for batch in batch_iter(set(seen_chains), 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "origin_chain_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in txn:
+                for chains in set_to_chain:
+                    # chains are only reachable if the origin sequence number of
+                    # the link is less than the max sequence number in the
+                    # origin chain.
+                    if origin_sequence_number <= chains.get(origin_chain_id, 0):
+                        chains[target_chain_id] = max(
+                            target_sequence_number, chains.get(target_chain_id, 0),
+                        )
+
+                seen_chains.add(target_chain_id)
+
+        # Now for each chain we figure out the maximum sequence number reachable
+        # from *any* state set and the minimum sequence number reachable from
+        # *all* state sets. Events in that range are in the auth chain
+        # difference.
+        result = set()
+
+        # Mapping from chain ID to the range of sequence numbers that should be
+        # pulled from the database.
+        chain_to_gap = {}  # type: Dict[int, Tuple[int, int]]
+
+        for chain_id in seen_chains:
+            min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
+            max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
+
+            if min_seq_no < max_seq_no:
+                # We have a non empty gap, try and fill it from the events that
+                # we have, otherwise add them to the list of gaps to pull out
+                # from the DB.
+                for seq_no in range(min_seq_no + 1, max_seq_no + 1):
+                    event_id = chain_to_event.get(chain_id, {}).get(seq_no)
+                    if event_id:
+                        result.add(event_id)
+                    else:
+                        chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
+                        break
+
+        if not chain_to_gap:
+            # If there are no gaps to fetch, we're done!
+            return result
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # We can use `execute_values` to efficiently fetch the gaps when
+            # using postgres.
+            sql = """
+                SELECT event_id
+                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
+                WHERE
+                    c.chain_id = l.chain_id
+                    AND min_seq < sequence_number AND sequence_number <= max_seq
+            """
+
+            args = [
+                (chain_id, min_no, max_no)
+                for chain_id, (min_no, max_no) in chain_to_gap.items()
+            ]
+
+            rows = txn.execute_values(sql, args)
+            result.update(r for r, in rows)
+        else:
+            # For SQLite we just fall back to doing a noddy for loop.
+            sql = """
+                SELECT event_id FROM event_auth_chains
+                WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
+            """
+            for chain_id, (min_no, max_no) in chain_to_gap.items():
+                txn.execute(sql, (chain_id, min_no, max_no))
+                result.update(r for r, in txn)
+
+        return result
+
     def _get_auth_chain_difference_txn(
         self, txn, state_sets: List[Set[str]]
     ) -> Set[str]:
+        """Calculates the auth chain difference using a breadth first search.
+
+        This is used when we don't have a cover index for the room.
+        """
 
         # Algorithm Description
         # ~~~~~~~~~~~~~~~~~~~~~
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 2e56dfaf31..438383abe1 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 VALUES (?, ?, ?, ?, ?, ?)
             """
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     _gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             ],
         )
 
-        txn.executemany(
+        txn.execute_batch(
             """
                 UPDATE event_push_summary
                 SET notif_count = ?, unread_count = ?, stream_ordering = ?
@@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             (rotate_to_stream_ordering,),
         )
 
+    def _remove_old_push_actions_before_txn(
+        self, txn, room_id, user_id, stream_ordering
+    ):
+        """
+        Purges old push actions for a user and room before a given
+        stream_ordering.
+
+        We however keep a months worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transcation
+            room_id: Room ID to delete from
+            user_id: user ID to delete for
+            stream_ordering: The lowest stream ordering which will
+                                  not be deleted.
+        """
+        txn.call_after(
+            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            (room_id, user_id),
+        )
+
+        # We need to join on the events table to get the received_ts for
+        # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore we can
+        # only identify event_push_actions by a tuple of room_id, event_id
+        # we we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_odering before that.
+        txn.execute(
+            "DELETE FROM event_push_actions "
+            " WHERE user_id = ? AND room_id = ? AND "
+            " stream_ordering <= ?"
+            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+        )
+
+        txn.execute(
+            """
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """,
+            (room_id, user_id, stream_ordering),
+        )
+
 
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -894,62 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions
 
-    async def get_latest_push_action_stream_ordering(self):
-        def f(txn):
-            txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
-            return txn.fetchone()
-
-        result = await self.db_pool.runInteraction(
-            "get_latest_push_action_stream_ordering", f
-        )
-        return result[0] or 0
-
-    def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
-
-        Args:
-            txn: The transcation
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                                  not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id, user_id),
-        )
-
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
-        )
-
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
-        )
-
 
 def _action_has_highlight(actions):
     for action in actions:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 90fb1a1f00..ccda9f1caa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,17 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
@@ -35,7 +45,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
-from synapse.util.iterutils import batch_iter
+from synapse.util.iterutils import batch_iter, sorted_topologically
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -366,6 +376,36 @@ class PersistEventsStore:
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
+        self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn, events_and_contexts=events_and_contexts
+        )
+
+        # From this point onwards the events are only ones that weren't
+        # rejected.
+
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+            backfilled=backfilled,
+        )
+
+        # We call this last as it assumes we've inserted the events into
+        # room_memberships, where applicable.
+        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+    def _persist_event_auth_chain_txn(
+        self, txn: LoggingTransaction, events: List[EventBase],
+    ) -> None:
+
+        # We only care about state events, so this if there are no state events.
+        if not any(e.is_state() for e in events):
+            return
+
         # We want to store event_auth mappings for rejected events, as they're
         # used in state res v2.
         # This is only necessary if the rejected event appears in an accepted
@@ -381,31 +421,467 @@ class PersistEventsStore:
                     "room_id": event.room_id,
                     "auth_id": auth_id,
                 }
-                for event, _ in events_and_contexts
+                for event in events
                 for auth_id in event.auth_event_ids()
                 if event.is_state()
             ],
         )
 
-        # _store_rejected_events_txn filters out any events which were
-        # rejected, and returns the filtered list.
-        events_and_contexts = self._store_rejected_events_txn(
-            txn, events_and_contexts=events_and_contexts
+        # We now calculate chain ID/sequence numbers for any state events we're
+        # persisting. We ignore out of band memberships as we're not in the room
+        # and won't have their auth chain (we'll fix it up later if we join the
+        # room).
+        #
+        # See: docs/auth_chain_difference_algorithm.md
+
+        # We ignore legacy rooms that we aren't filling the chain cover index
+        # for.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="rooms",
+            column="room_id",
+            iterable={event.room_id for event in events if event.is_state()},
+            keyvalues={},
+            retcols=("room_id", "has_auth_chain_index"),
         )
+        rooms_using_chain_index = {
+            row["room_id"] for row in rows if row["has_auth_chain_index"]
+        }
 
-        # From this point onwards the events are only ones that weren't
-        # rejected.
+        state_events = {
+            event.event_id: event
+            for event in events
+            if event.is_state() and event.room_id in rooms_using_chain_index
+        }
 
-        self._update_metadata_tables_txn(
+        if not state_events:
+            return
+
+        # We need to know the type/state_key and auth events of the events we're
+        # calculating chain IDs for. We don't rely on having the full Event
+        # instances as we'll potentially be pulling more events from the DB and
+        # we don't need the overhead of fetching/parsing the full event JSON.
+        event_to_types = {
+            e.event_id: (e.type, e.state_key) for e in state_events.values()
+        }
+        event_to_auth_chain = {
+            e.event_id: e.auth_event_ids() for e in state_events.values()
+        }
+        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+        self._add_chain_cover_index(
+            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+        )
+
+    @classmethod
+    def _add_chain_cover_index(
+        cls,
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+    ) -> None:
+        """Calculate the chain cover index for the given events.
+
+        Args:
+            event_to_room_id: Event ID to the room ID of the event
+            event_to_types: Event ID to type and state_key of the event
+            event_to_auth_chain: Event ID to list of auth event IDs of the
+                event (events with no auth events can be excluded).
+        """
+
+        # Map from event ID to chain ID/sequence number.
+        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+
+        # Set of event IDs to calculate chain ID/seq numbers for.
+        events_to_calc_chain_id_for = set(event_to_room_id)
+
+        # We check if there are any events that need to be handled in the rooms
+        # we're looking at. These should just be out of band memberships, where
+        # we didn't have the auth chain when we first persisted.
+        rows = db_pool.simple_select_many_txn(
             txn,
-            events_and_contexts=events_and_contexts,
-            all_events_and_contexts=all_events_and_contexts,
-            backfilled=backfilled,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="room_id",
+            iterable=set(event_to_room_id.values()),
+            retcols=("event_id", "type", "state_key"),
         )
+        for row in rows:
+            event_id = row["event_id"]
+            event_type = row["type"]
+            state_key = row["state_key"]
+
+            # (We could pull out the auth events for all rows at once using
+            # simple_select_many, but this case happens rarely and almost always
+            # with a single row.)
+            auth_events = db_pool.simple_select_onecol_txn(
+                txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
+            )
 
-        # We call this last as it assumes we've inserted the events into
-        # room_memberships, where applicable.
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+            events_to_calc_chain_id_for.add(event_id)
+            event_to_types[event_id] = (event_type, state_key)
+            event_to_auth_chain[event_id] = auth_events
+
+        # First we get the chain ID and sequence numbers for the events'
+        # auth events (that aren't also currently being persisted).
+        #
+        # Note that there there is an edge case here where we might not have
+        # calculated chains and sequence numbers for events that were "out
+        # of band". We handle this case by fetching the necessary info and
+        # adding it to the set of events to calculate chain IDs for.
+
+        missing_auth_chains = {
+            a_id
+            for auth_events in event_to_auth_chain.values()
+            for a_id in auth_events
+            if a_id not in events_to_calc_chain_id_for
+        }
+
+        # We loop here in case we find an out of band membership and need to
+        # fetch their auth event info.
+        while missing_auth_chains:
+            sql = """
+                SELECT event_id, events.type, state_key, chain_id, sequence_number
+                FROM events
+                INNER JOIN state_events USING (event_id)
+                LEFT JOIN event_auth_chains USING (event_id)
+                WHERE
+            """
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", missing_auth_chains,
+            )
+            txn.execute(sql + clause, args)
+
+            missing_auth_chains.clear()
+
+            for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+                event_to_types[auth_id] = (event_type, state_key)
+
+                if chain_id is None:
+                    # No chain ID, so the event was persisted out of band.
+                    # We add to list of events to calculate auth chains for.
+
+                    events_to_calc_chain_id_for.add(auth_id)
+
+                    event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
+                        txn,
+                        "event_auth",
+                        keyvalues={"event_id": auth_id},
+                        retcol="auth_id",
+                    )
+
+                    missing_auth_chains.update(
+                        e
+                        for e in event_to_auth_chain[auth_id]
+                        if e not in event_to_types
+                    )
+                else:
+                    chain_map[auth_id] = (chain_id, sequence_number)
+
+        # Now we check if we have any events where we don't have auth chain,
+        # this should only be out of band memberships.
+        for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
+            for auth_id in event_to_auth_chain[event_id]:
+                if (
+                    auth_id not in chain_map
+                    and auth_id not in events_to_calc_chain_id_for
+                ):
+                    events_to_calc_chain_id_for.discard(event_id)
+
+                    # If this is an event we're trying to persist we add it to
+                    # the list of events to calculate chain IDs for next time
+                    # around. (Otherwise we will have already added it to the
+                    # table).
+                    room_id = event_to_room_id.get(event_id)
+                    if room_id:
+                        e_type, state_key = event_to_types[event_id]
+                        db_pool.simple_insert_txn(
+                            txn,
+                            table="event_auth_chain_to_calculate",
+                            values={
+                                "event_id": event_id,
+                                "room_id": room_id,
+                                "type": e_type,
+                                "state_key": state_key,
+                            },
+                        )
+
+                    # We stop checking the event's auth events since we've
+                    # discarded it.
+                    break
+
+        if not events_to_calc_chain_id_for:
+            return
+
+        # Allocate chain ID/sequence numbers to each new event.
+        new_chain_tuples = cls._allocate_chain_ids(
+            txn,
+            db_pool,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
+            events_to_calc_chain_id_for,
+            chain_map,
+        )
+        chain_map.update(new_chain_tuples)
+
+        db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chains",
+            values=[
+                {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+                for event_id, (c_id, seq) in new_chain_tuples.items()
+            ],
+        )
+
+        db_pool.simple_delete_many_txn(
+            txn,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="event_id",
+            iterable=new_chain_tuples,
+        )
+
+        # Now we need to calculate any new links between chains caused by
+        # the new events.
+        #
+        # Links are pairs of chain ID/sequence numbers such that for any
+        # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
+        # if and only if there is at least one link (CA, S1) -> (CB, S2)
+        # where SA >= S1 and S2 >= SB.
+        #
+        # We try and avoid adding redundant links to the table, e.g. if we
+        # have two links between two chains which both start/end at the
+        # sequence number event (or cross) then one can be safely dropped.
+        #
+        # To calculate new links we look at every new event and:
+        #   1. Fetch the chain ID/sequence numbers of its auth events,
+        #      discarding any that are reachable by other auth events, or
+        #      that have the same chain ID as the event.
+        #   2. For each retained auth event we:
+        #       a. Add a link from the event's to the auth event's chain
+        #          ID/sequence number; and
+        #       b. Add a link from the event to every chain reachable by the
+        #          auth event.
+
+        # Step 1, fetch all existing links from all the chains we've seen
+        # referenced.
+        chain_links = _LinkMap()
+        rows = db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth_chain_links",
+            column="origin_chain_id",
+            iterable={chain_id for chain_id, _ in chain_map.values()},
+            keyvalues={},
+            retcols=(
+                "origin_chain_id",
+                "origin_sequence_number",
+                "target_chain_id",
+                "target_sequence_number",
+            ),
+        )
+        for row in rows:
+            chain_links.add_link(
+                (row["origin_chain_id"], row["origin_sequence_number"]),
+                (row["target_chain_id"], row["target_sequence_number"]),
+                new=False,
+            )
+
+        # We do this in toplogical order to avoid adding redundant links.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            chain_id, sequence_number = chain_map[event_id]
+
+            # Filter out auth events that are reachable by other auth
+            # events. We do this by looking at every permutation of pairs of
+            # auth events (A, B) to check if B is reachable from A.
+            reduction = {
+                a_id
+                for a_id in event_to_auth_chain.get(event_id, [])
+                if chain_map[a_id][0] != chain_id
+            }
+            for start_auth_id, end_auth_id in itertools.permutations(
+                event_to_auth_chain.get(event_id, []), r=2,
+            ):
+                if chain_links.exists_path_from(
+                    chain_map[start_auth_id], chain_map[end_auth_id]
+                ):
+                    reduction.discard(end_auth_id)
+
+            # Step 2, figure out what the new links are from the reduced
+            # list of auth events.
+            for auth_id in reduction:
+                auth_chain_id, auth_sequence_number = chain_map[auth_id]
+
+                # Step 2a, add link between the event and auth event
+                chain_links.add_link(
+                    (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
+                )
+
+                # Step 2b, add a link to chains reachable from the auth
+                # event.
+                for target_id, target_seq in chain_links.get_links_from(
+                    (auth_chain_id, auth_sequence_number)
+                ):
+                    if target_id == chain_id:
+                        continue
+
+                    chain_links.add_link(
+                        (chain_id, sequence_number), (target_id, target_seq)
+                    )
+
+        db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chain_links",
+            values=[
+                {
+                    "origin_chain_id": source_id,
+                    "origin_sequence_number": source_seq,
+                    "target_chain_id": target_id,
+                    "target_sequence_number": target_seq,
+                }
+                for (
+                    source_id,
+                    source_seq,
+                    target_id,
+                    target_seq,
+                ) in chain_links.get_additions()
+            ],
+        )
+
+    @staticmethod
+    def _allocate_chain_ids(
+        txn,
+        db_pool: DatabasePool,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+        events_to_calc_chain_id_for: Set[str],
+        chain_map: Dict[str, Tuple[int, int]],
+    ) -> Dict[str, Tuple[int, int]]:
+        """Allocates, but does not persist, chain ID/sequence numbers for the
+        events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+        for info on args)
+        """
+
+        # We now calculate the chain IDs/sequence numbers for the events. We do
+        # this by looking at the chain ID and sequence number of any auth event
+        # with the same type/state_key and incrementing the sequence number by
+        # one. If there was no match or the chain ID/sequence number is already
+        # taken we generate a new chain.
+        #
+        # We try to reduce the number of times that we hit the database by
+        # batching up calls, to make this more efficient when persisting large
+        # numbers of state events (e.g. during joins).
+        #
+        # We do this by:
+        #   1. Calculating for each event which auth event will be used to
+        #      inherit the chain ID, i.e. converting the auth chain graph to a
+        #      tree that we can allocate chains on. We also keep track of which
+        #      existing chain IDs have been referenced.
+        #   2. Fetching the max allocated sequence number for each referenced
+        #      existing chain ID, generating a map from chain ID to the max
+        #      allocated sequence number.
+        #   3. Iterating over the tree and allocating a chain ID/seq no. to the
+        #      new event, by incrementing the sequence number from the
+        #      referenced event's chain ID/seq no. and checking that the
+        #      incremented sequence number hasn't already been allocated (by
+        #      looking in the map generated in the previous step). We generate a
+        #      new chain if the sequence number has already been allocated.
+        #
+
+        existing_chains = set()  # type: Set[int]
+        tree = []  # type: List[Tuple[str, Optional[str]]]
+
+        # We need to do this in a topologically sorted order as we want to
+        # generate chain IDs/sequence numbers of an event's auth events before
+        # the event itself.
+        for event_id in sorted_topologically(
+            events_to_calc_chain_id_for, event_to_auth_chain
+        ):
+            for auth_id in event_to_auth_chain.get(event_id, []):
+                if event_to_types.get(event_id) == event_to_types.get(auth_id):
+                    existing_chain_id = chain_map.get(auth_id)
+                    if existing_chain_id:
+                        existing_chains.add(existing_chain_id[0])
+
+                    tree.append((event_id, auth_id))
+                    break
+            else:
+                tree.append((event_id, None))
+
+        # Fetch the current max sequence number for each existing referenced chain.
+        sql = """
+            SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+            WHERE %s
+            GROUP BY chain_id
+        """
+        clause, args = make_in_list_sql_clause(
+            db_pool.engine, "chain_id", existing_chains
+        )
+        txn.execute(sql % (clause,), args)
+
+        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]
+
+        # Allocate the new events chain ID/sequence numbers.
+        #
+        # To reduce the number of calls to the database we don't allocate a
+        # chain ID number in the loop, instead we use a temporary `object()` for
+        # each new chain ID. Once we've done the loop we generate the necessary
+        # number of new chain IDs in one call, replacing all temporary
+        # objects with real allocated chain IDs.
+
+        unallocated_chain_ids = set()  # type: Set[object]
+        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
+        for event_id, auth_event_id in tree:
+            # If we reference an auth_event_id we fetch the allocated chain ID,
+            # either from the existing `chain_map` or the newly generated
+            # `new_chain_tuples` map.
+            existing_chain_id = None
+            if auth_event_id:
+                existing_chain_id = new_chain_tuples.get(auth_event_id)
+                if not existing_chain_id:
+                    existing_chain_id = chain_map[auth_event_id]
+
+            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
+            if existing_chain_id:
+                # We found a chain ID/sequence number candidate, check its
+                # not already taken.
+                proposed_new_id = existing_chain_id[0]
+                proposed_new_seq = existing_chain_id[1] + 1
+
+                if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+                    new_chain_tuple = (
+                        proposed_new_id,
+                        proposed_new_seq,
+                    )
+
+            # If we need to start a new chain we allocate a temporary chain ID.
+            if not new_chain_tuple:
+                new_chain_tuple = (object(), 1)
+                unallocated_chain_ids.add(new_chain_tuple[0])
+
+            new_chain_tuples[event_id] = new_chain_tuple
+            chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+        # Generate new chain IDs for all unallocated chain IDs.
+        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+            txn, len(unallocated_chain_ids)
+        )
+
+        # Map from potentially temporary chain ID to real chain ID
+        chain_id_to_allocated_map = dict(
+            zip(unallocated_chain_ids, newly_allocated_chain_ids)
+        )  # type: Dict[Any, int]
+        chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+        return {
+            event_id: (chain_id_to_allocated_map[chain_id], seq)
+            for event_id, (chain_id, seq) in new_chain_tuples.items()
+        }
 
     def _persist_transaction_ids_txn(
         self,
@@ -489,7 +965,7 @@ class PersistEventsStore:
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
                 """
-                txn.executemany(
+                txn.execute_batch(
                     sql,
                     (
                         (
@@ -508,7 +984,7 @@ class PersistEventsStore:
                 )
                 # Now we actually update the current_state_events table
 
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM current_state_events"
                     " WHERE room_id = ? AND type = ? AND state_key = ?",
                     (
@@ -520,7 +996,7 @@ class PersistEventsStore:
                 # We include the membership in the current state table, hence we do
                 # a lookup when we insert. This assumes that all events have already
                 # been inserted into room_memberships.
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO current_state_events
                         (room_id, type, state_key, event_id, membership)
                     VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -540,7 +1016,7 @@ class PersistEventsStore:
             # we have no record of the fact the user *was* a member of the
             # room but got, say, state reset out of it.
             if to_delete or to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     "DELETE FROM local_current_membership"
                     " WHERE room_id = ? AND user_id = ?",
                     (
@@ -551,7 +1027,7 @@ class PersistEventsStore:
                 )
 
             if to_insert:
-                txn.executemany(
+                txn.execute_batch(
                     """INSERT INTO local_current_membership
                         (room_id, user_id, event_id, membership)
                     VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -799,7 +1275,8 @@ class PersistEventsStore:
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
     def _store_event_txn(self, txn, events_and_contexts):
-        """Insert new events into the event and event_json tables
+        """Insert new events into the event, event_json, redaction and
+        state_events tables.
 
         Args:
             txn (twisted.enterprise.adbapi.Connection): db connection
@@ -871,6 +1348,29 @@ class PersistEventsStore:
                     updatevalues={"have_censored": False},
                 )
 
+        state_events_and_contexts = [
+            ec for ec in events_and_contexts if ec[0].is_state()
+        ]
+
+        state_values = []
+        for event, context in state_events_and_contexts:
+            vals = {
+                "event_id": event.event_id,
+                "room_id": event.room_id,
+                "type": event.type,
+                "state_key": event.state_key,
+            }
+
+            # TODO: How does this work with backfilling?
+            if hasattr(event, "replaces_state"):
+                vals["prev_state"] = event.replaces_state
+
+            state_values.append(vals)
+
+        self.db_pool.simple_insert_many_txn(
+            txn, table="state_events", values=state_values
+        )
+
     def _store_rejected_events_txn(self, txn, events_and_contexts):
         """Add rows to the 'rejections' table for received events which were
         rejected
@@ -987,29 +1487,6 @@ class PersistEventsStore:
             txn, [event for event, _ in events_and_contexts]
         )
 
-        state_events_and_contexts = [
-            ec for ec in events_and_contexts if ec[0].is_state()
-        ]
-
-        state_values = []
-        for event, context in state_events_and_contexts:
-            vals = {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "type": event.type,
-                "state_key": event.state_key,
-            }
-
-            # TODO: How does this work with backfilling?
-            if hasattr(event, "replaces_state"):
-                vals["prev_state"] = event.replaces_state
-
-            state_values.append(vals)
-
-        self.db_pool.simple_insert_many_txn(
-            txn, table="state_events", values=state_values
-        )
-
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
@@ -1350,7 +1827,7 @@ class PersistEventsStore:
         """
 
         if events_and_contexts:
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (
@@ -1379,7 +1856,7 @@ class PersistEventsStore:
 
         # Now we delete the staging area for *all* events that were being
         # persisted.
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM event_push_actions_staging WHERE event_id = ?",
             ((event.event_id,) for event, _ in all_events_and_contexts),
         )
@@ -1498,7 +1975,7 @@ class PersistEventsStore:
             " )"
         )
 
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1512,7 +1989,7 @@ class PersistEventsStore:
             "DELETE FROM event_backward_extremities"
             " WHERE event_id = ? AND room_id = ?"
         )
-        txn.executemany(
+        txn.execute_batch(
             query,
             [
                 (ev.event_id, ev.room_id)
@@ -1520,3 +1997,131 @@ class PersistEventsStore:
                 if not ev.internal_metadata.is_outlier()
             ],
         )
+
+
+@attr.s(slots=True)
+class _LinkMap:
+    """A helper type for tracking links between chains.
+    """
+
+    # Stores the set of links as nested maps: source chain ID -> target chain ID
+    # -> source sequence number -> target sequence number.
+    maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+
+    # Stores the links that have been added (with new set to true), as tuples of
+    # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
+    additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+
+    def add_link(
+        self,
+        src_tuple: Tuple[int, int],
+        target_tuple: Tuple[int, int],
+        new: bool = True,
+    ) -> bool:
+        """Add a new link between two chains, ensuring no redundant links are added.
+
+        New links should be added in topological order.
+
+        Args:
+            src_tuple: The chain ID/sequence number of the source of the link.
+            target_tuple: The chain ID/sequence number of the target of the link.
+            new: Whether this is a "new" link, i.e. should it be returned
+                by `get_additions`.
+
+        Returns:
+            True if a link was added, false if the given link was dropped as redundant
+        """
+        src_chain, src_seq = src_tuple
+        target_chain, target_seq = target_tuple
+
+        current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
+
+        assert src_chain != target_chain
+
+        if new:
+            # Check if the new link is redundant
+            for current_seq_src, current_seq_target in current_links.items():
+                # If a link "crosses" another link then its redundant. For example
+                # in the following link 1 (L1) is redundant, as any event reachable
+                # via L1 is *also* reachable via L2.
+                #
+                #   Chain A     Chain B
+                #      |          |
+                #   L1 |------    |
+                #      |     |    |
+                #   L2 |---- | -->|
+                #      |     |    |
+                #      |     |--->|
+                #      |          |
+                #      |          |
+                #
+                # So we only need to keep links which *do not* cross, i.e. links
+                # that both start and end above or below an existing link.
+                #
+                # Note, since we add links in topological ordering we should never
+                # see `src_seq` less than `current_seq_src`.
+
+                if current_seq_src <= src_seq and target_seq <= current_seq_target:
+                    # This new link is redundant, nothing to do.
+                    return False
+
+            self.additions.add((src_chain, src_seq, target_chain, target_seq))
+
+        current_links[src_seq] = target_seq
+        return True
+
+    def get_links_from(
+        self, src_tuple: Tuple[int, int]
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Gets the chains reachable from the given chain/sequence number.
+
+        Yields:
+            The chain ID and sequence number the link points to.
+        """
+        src_chain, src_seq = src_tuple
+        for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
+            for link_src_seq, target_seq in sequence_numbers.items():
+                if link_src_seq <= src_seq:
+                    yield target_id, target_seq
+
+    def get_links_between(
+        self, source_chain: int, target_chain: int
+    ) -> Generator[Tuple[int, int], None, None]:
+        """Gets the links between two chains.
+
+        Yields:
+            The source and target sequence numbers.
+        """
+
+        yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
+
+    def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
+        """Gets any newly added links.
+
+        Yields:
+            The source chain ID/sequence number and target chain ID/sequence number
+        """
+
+        for src_chain, src_seq, target_chain, _ in self.additions:
+            target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
+            if target_seq is not None:
+                yield (src_chain, src_seq, target_chain, target_seq)
+
+    def exists_path_from(
+        self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
+    ) -> bool:
+        """Checks if there is a path between the source chain ID/sequence and
+        target chain ID/sequence.
+        """
+        src_chain, src_seq = src_tuple
+        target_chain, target_seq = target_tuple
+
+        if src_chain == target_chain:
+            return target_seq <= src_seq
+
+        links = self.get_links_between(src_chain, target_chain)
+        for link_start_seq, link_end_seq in links:
+            if link_start_seq <= src_seq and target_seq <= link_end_seq:
+                return True
+
+        return False
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 97b6754846..5ca4fa6817 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -14,14 +14,41 @@
 # limitations under the License.
 
 import logging
+from typing import Dict, List, Optional, Tuple
+
+import attr
 
 from synapse.api.constants import EventContentFields
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.databases.main.events import PersistEventsStore
+from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True)
+class _CalculateChainCover:
+    """Return value for _calculate_chain_cover_txn.
+    """
+
+    # The last room_id/depth/stream processed.
+    room_id = attr.ib(type=str)
+    depth = attr.ib(type=int)
+    stream = attr.ib(type=int)
+
+    # Number of rows processed
+    processed_count = attr.ib(type=int)
+
+    # Map from room_id to last depth/stream processed for each room that we have
+    # processed all events for (i.e. the rooms we can flip the
+    # `has_auth_chain_index` for)
+    finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+
+
 class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@@ -99,13 +126,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             columns=["user_id", "created_ts"],
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "rejected_events_metadata", self._rejected_events_metadata,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            "chain_cover", self._chain_cover_index,
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
@@ -143,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
 
-            for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
-                clump = update_rows[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, update_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -175,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
-
         def reindex_search_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
@@ -221,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
 
-            for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
-                clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            txn.execute_batch(sql, rows_to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -582,3 +609,314 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             await self.db_pool.updates._end_background_update("event_store_labels")
 
         return num_rows
+
+    async def _rejected_events_metadata(self, progress: dict, batch_size: int) -> int:
+        """Adds rejected events to the `state_events` and `event_auth` metadata
+        tables.
+        """
+
+        last_event_id = progress.get("last_event_id", "")
+
+        def get_rejected_events(
+            txn: Cursor,
+        ) -> List[Tuple[str, str, JsonDict, bool, bool]]:
+            # Fetch rejected event json, their room version and whether we have
+            # inserted them into the state_events or auth_events tables.
+            #
+            # Note we can assume that events that don't have a corresponding
+            # room version are V1 rooms.
+            sql = """
+                SELECT DISTINCT
+                    event_id,
+                    COALESCE(room_version, '1'),
+                    json,
+                    state_events.event_id IS NOT NULL,
+                    event_auth.event_id IS NOT NULL
+                FROM rejections
+                INNER JOIN event_json USING (event_id)
+                LEFT JOIN rooms USING (room_id)
+                LEFT JOIN state_events USING (event_id)
+                LEFT JOIN event_auth USING (event_id)
+                WHERE event_id > ?
+                ORDER BY event_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_event_id, batch_size,))
+
+            return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
+
+        results = await self.db_pool.runInteraction(
+            desc="_rejected_events_metadata_get", func=get_rejected_events
+        )
+
+        if not results:
+            await self.db_pool.updates._end_background_update(
+                "rejected_events_metadata"
+            )
+            return 0
+
+        state_events = []
+        auth_events = []
+        for event_id, room_version, event_json, has_state, has_event_auth in results:
+            last_event_id = event_id
+
+            if has_state and has_event_auth:
+                continue
+
+            room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version)
+            if not room_version_obj:
+                # We no longer support this room version, so we just ignore the
+                # events entirely.
+                logger.info(
+                    "Ignoring event with unknown room version %r: %r",
+                    room_version,
+                    event_id,
+                )
+                continue
+
+            event = make_event_from_dict(event_json, room_version_obj)
+
+            if not event.is_state():
+                continue
+
+            if not has_state:
+                state_events.append(
+                    {
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "type": event.type,
+                        "state_key": event.state_key,
+                    }
+                )
+
+            if not has_event_auth:
+                for auth_id in event.auth_event_ids():
+                    auth_events.append(
+                        {
+                            "room_id": event.room_id,
+                            "event_id": event.event_id,
+                            "auth_id": auth_id,
+                        }
+                    )
+
+        if state_events:
+            await self.db_pool.simple_insert_many(
+                table="state_events",
+                values=state_events,
+                desc="_rejected_events_metadata_state_events",
+            )
+
+        if auth_events:
+            await self.db_pool.simple_insert_many(
+                table="event_auth",
+                values=auth_events,
+                desc="_rejected_events_metadata_event_auth",
+            )
+
+        await self.db_pool.updates._background_update_progress(
+            "rejected_events_metadata", {"last_event_id": last_event_id}
+        )
+
+        if len(results) < batch_size:
+            await self.db_pool.updates._end_background_update(
+                "rejected_events_metadata"
+            )
+
+        return len(results)
+
+    async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
+        """A background updates that iterates over all rooms and generates the
+        chain cover index for them.
+        """
+
+        current_room_id = progress.get("current_room_id", "")
+
+        # Where we've processed up to in the room, defaults to the start of the
+        # room.
+        last_depth = progress.get("last_depth", -1)
+        last_stream = progress.get("last_stream", -1)
+
+        result = await self.db_pool.runInteraction(
+            "_chain_cover_index",
+            self._calculate_chain_cover_txn,
+            current_room_id,
+            last_depth,
+            last_stream,
+            batch_size,
+            single_room=False,
+        )
+
+        finished = result.processed_count == 0
+
+        total_rows_processed = result.processed_count
+        current_room_id = result.room_id
+        last_depth = result.depth
+        last_stream = result.stream
+
+        for room_id, (depth, stream) in result.finished_room_map.items():
+            # If we've done all the events in the room we flip the
+            # `has_auth_chain_index` in the DB. Note that its possible for
+            # further events to be persisted between the above and setting the
+            # flag without having the chain cover calculated for them. This is
+            # fine as a) the code gracefully handles these cases and b) we'll
+            # calculate them below.
+
+            await self.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": True},
+                desc="_chain_cover_index",
+            )
+
+            # Handle any events that might have raced with us flipping the
+            # bit above.
+            result = await self.db_pool.runInteraction(
+                "_chain_cover_index",
+                self._calculate_chain_cover_txn,
+                room_id,
+                depth,
+                stream,
+                batch_size=None,
+                single_room=True,
+            )
+
+            total_rows_processed += result.processed_count
+
+        if finished:
+            await self.db_pool.updates._end_background_update("chain_cover")
+            return total_rows_processed
+
+        await self.db_pool.updates._background_update_progress(
+            "chain_cover",
+            {
+                "current_room_id": current_room_id,
+                "last_depth": last_depth,
+                "last_stream": last_stream,
+            },
+        )
+
+        return total_rows_processed
+
+    def _calculate_chain_cover_txn(
+        self,
+        txn: Cursor,
+        last_room_id: str,
+        last_depth: int,
+        last_stream: int,
+        batch_size: Optional[int],
+        single_room: bool,
+    ) -> _CalculateChainCover:
+        """Calculate the chain cover for `batch_size` events, ordered by
+        `(room_id, depth, stream)`.
+
+        Args:
+            txn,
+            last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
+                tuple to fetch results after.
+            batch_size: The maximum number of events to process. If None then
+                no limit.
+            single_room: Whether to calculate the index for just the given
+                room.
+        """
+
+        # Get the next set of events in the room (that we haven't already
+        # computed chain cover for). We do this in topological order.
+
+        # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+        # comparison, but that is not supported on older SQLite versions
+        tuple_clause, tuple_args = make_tuple_comparison_clause(
+            self.database_engine,
+            [
+                ("events.room_id", last_room_id),
+                ("topological_ordering", last_depth),
+                ("stream_ordering", last_stream),
+            ],
+        )
+
+        extra_clause = ""
+        if single_room:
+            extra_clause = "AND events.room_id = ?"
+            tuple_args.append(last_room_id)
+
+        sql = """
+            SELECT
+                event_id, state_events.type, state_events.state_key,
+                topological_ordering, stream_ordering,
+                events.room_id
+            FROM events
+            INNER JOIN state_events USING (event_id)
+            LEFT JOIN event_auth_chains USING (event_id)
+            LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+            WHERE event_auth_chains.event_id IS NULL
+                AND event_auth_chain_to_calculate.event_id IS NULL
+                AND %(tuple_cmp)s
+                %(extra)s
+            ORDER BY events.room_id, topological_ordering, stream_ordering
+            %(limit)s
+        """ % {
+            "tuple_cmp": tuple_clause,
+            "limit": "LIMIT ?" if batch_size is not None else "",
+            "extra": extra_clause,
+        }
+
+        if batch_size is not None:
+            tuple_args.append(batch_size)
+
+        txn.execute(sql, tuple_args)
+        rows = txn.fetchall()
+
+        # Put the results in the necessary format for
+        # `_add_chain_cover_index`
+        event_to_room_id = {row[0]: row[5] for row in rows}
+        event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+        # Calculate the new last position we've processed up to.
+        new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+        new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+        new_last_room_id = rows[-1][5] if rows else ""  # type: str
+
+        # Map from room_id to last depth/stream_ordering processed for the room,
+        # excluding the last room (which we're likely still processing). We also
+        # need to include the room passed in if it's not included in the result
+        # set (as we then know we've processed all events in said room).
+        #
+        # This is the set of rooms that we can now safely flip the
+        # `has_auth_chain_index` bit for.
+        finished_rooms = {
+            row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+        }
+        if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+            finished_rooms[last_room_id] = (last_depth, last_stream)
+
+        count = len(rows)
+
+        # We also need to fetch the auth events for them.
+        auth_events = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth",
+            column="event_id",
+            iterable=event_to_room_id,
+            keyvalues={},
+            retcols=("event_id", "auth_id"),
+        )
+
+        event_to_auth_chain = {}  # type: Dict[str, List[str]]
+        for row in auth_events:
+            event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+        # Calculate and persist the chain cover index for this set of events.
+        #
+        # Annoyingly we need to gut wrench into the persit event store so that
+        # we can reuse the function to calculate the chain cover for rooms.
+        PersistEventsStore._add_chain_cover_index(
+            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+        )
+
+        return _CalculateChainCover(
+            room_id=new_last_room_id,
+            depth=new_last_depth,
+            stream=new_last_stream,
+            processed_count=count,
+            finished_room_map=finished_rooms,
+        )
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..0ac1da9c35
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+    async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+        """Delete any extra forward extremities for a room.
+
+        Invalidates the "get_latest_event_ids_in_room" cache if any forward
+        extremities were deleted.
+
+        Returns count deleted.
+        """
+
+        def delete_forward_extremities_for_room_txn(txn):
+            # First we need to get the event_id to not delete
+            sql = """
+                SELECT event_id FROM event_forward_extremities
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+                ORDER BY stream_ordering DESC
+                LIMIT 1
+            """
+            txn.execute(sql, (room_id,))
+            rows = txn.fetchall()
+            try:
+                event_id = rows[0][0]
+                logger.debug(
+                    "Found event_id %s as the forward extremity to keep for room %s",
+                    event_id,
+                    room_id,
+                )
+            except KeyError:
+                msg = "No forward extremity event found for room %s" % room_id
+                logger.warning(msg)
+                raise SynapseError(400, msg)
+
+            # Now delete the extra forward extremities
+            sql = """
+                DELETE FROM event_forward_extremities
+                WHERE event_id != ? AND room_id = ?
+            """
+
+            txn.execute(sql, (event_id, room_id))
+            logger.info(
+                "Deleted %s extra forward extremities for room %s",
+                txn.rowcount,
+                room_id,
+            )
+
+            if txn.rowcount > 0:
+                # Invalidate the cache
+                self._invalidate_cache_and_stream(
+                    txn, self.get_latest_event_ids_in_room, (room_id,),
+                )
+
+            return txn.rowcount
+
+        return await self.db_pool.runInteraction(
+            "delete_forward_extremities_for_room",
+            delete_forward_extremities_for_room_txn,
+        )
+
+    async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+        """Get list of forward extremities for a room."""
+
+        def get_forward_extremities_for_room_txn(txn):
+            sql = """
+                SELECT event_id, state_group, depth, received_ts
+                FROM event_forward_extremities
+                INNER JOIN event_to_state_groups USING (event_id)
+                INNER JOIN events USING (room_id, event_id)
+                WHERE room_id = ?
+            """
+
+            txn.execute(sql, (room_id,))
+            return self.db_pool.cursor_to_dict(txn)
+
+        return await self.db_pool.runInteraction(
+            "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
+        )
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4732685f6e..71d823be72 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="events",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_stream_seq",
                 writers=hs.config.worker.writers.events,
             )
@@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="backfill",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_backfill_stream_seq",
                 positive=False,
                 writers=hs.config.worker.writers.events,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index f8f4bb9b3f..04ac2d0ced 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.keys import FetchKeyResult
+from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
 
@@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore):
     )
     async def get_server_verify_keys(
         self, server_name_and_key_ids: Iterable[Tuple[str, str]]
-    ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
+    ) -> Dict[Tuple[str, str], FetchKeyResult]:
         """
         Args:
             server_name_and_key_ids:
@@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore):
         """
         keys = {}
 
-        def _get_keys(txn, batch):
+        def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
             """Processes a batch of keys to fetch, and adds the result to `keys`."""
 
             # batch_iter always returns tuples so it's safe to do len(batch)
@@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore):
                     # `ts_valid_until_ms`.
                     ts_valid_until_ms = 0
 
-                res = FetchKeyResult(
+                keys[(server_name, key_id)] = FetchKeyResult(
                     verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
                     valid_until_ts=ts_valid_until_ms,
                 )
-                keys[(server_name, key_id)] = res
 
-        def _txn(txn):
+        def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
             for batch in batch_iter(server_name_and_key_ids, 50):
                 _get_keys(txn, batch)
             return keys
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 4b2f224718..e017177655 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_local_media_before(
         self, before_ts: int, size_gt: int, keep_profiles: bool,
-    ) -> Optional[List[str]]:
+    ) -> List[str]:
 
         # to find files that have never been accessed (last_access_ts IS NULL)
         # compare with `created_ts`
@@ -416,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_origin = ? AND media_id = ?"
             )
 
-            txn.executemany(
+            txn.execute_batch(
                 sql,
                 (
                     (time_ms, media_origin, media_id)
@@ -429,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_id = ?"
             )
 
-            txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+            txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
 
         return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
@@ -556,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache", _delete_url_cache_txn
@@ -585,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_media_txn(txn):
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
             sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
-            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..92e65aa640 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             (x[0] - 1) * x[1] for x in res if x[1]
         )
 
+    async def count_daily_e2ee_messages(self):
+        """
+        Returns an estimate of the number of messages sent in the last day.
+
+        If it has been significantly less or more than one day since the last
+        call to this function, it will return None.
+        """
+
+        def _count_messages(txn):
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
+
+    async def count_daily_sent_e2ee_messages(self):
+        def _count_messages(txn):
+            # This is good enough as if you have silly characters in your own
+            # hostname then thats your own fault.
+            like_clause = "%:" + self.hs.hostname
+
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_sent_e2ee_messages", _count_messages
+        )
+
+    async def count_daily_active_e2ee_rooms(self):
+        def _count(txn):
+            sql = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.encrypted'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            (count,) = txn.fetchone()
+            return count
+
+        return await self.db_pool.runInteraction(
+            "count_daily_active_e2ee_rooms", _count
+        )
+
     async def count_daily_messages(self):
         """
         Returns an estimate of the number of messages sent in the last day.
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 0e25ca3d7a..54ef0f1f54 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -82,7 +82,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def set_profile_avatar_url(
-        self, user_localpart: str, new_avatar_url: str
+        self, user_localpart: str, new_avatar_url: Optional[str]
     ) -> None:
         await self.db_pool.simple_update_one(
             table="profiles",
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         )
 
         # Update backward extremeties
-        txn.executemany(
+        txn.execute_batch(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
             [(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7997242d90..2687ef3e43 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -15,18 +15,31 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, Iterator, List, Tuple
-
-from canonicaljson import encode_canonical_json
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
 
+from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
+from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class PusherWorkerStore(SQLBaseStore):
-    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+        super().__init__(database, db_conn, hs)
+        self._pushers_id_gen = StreamIdGenerator(
+            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+        )
+
+    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
         """JSON-decode the data in the rows returned from the `pushers` table
 
         Drops any rows whose data cannot be decoded
@@ -44,21 +57,23 @@ class PusherWorkerStore(SQLBaseStore):
                 )
                 continue
 
-            yield r
+            yield PusherConfig(**r)
 
-    async def user_has_pusher(self, user_id):
+    async def user_has_pusher(self, user_id: str) -> bool:
         ret = await self.db_pool.simple_select_one_onecol(
             "pushers", {"user_name": user_id}, "id", allow_none=True
         )
         return ret is not None
 
-    def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
-        return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
+    async def get_pushers_by_app_id_and_pushkey(
+        self, app_id: str, pushkey: str
+    ) -> Iterator[PusherConfig]:
+        return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
 
-    def get_pushers_by_user_id(self, user_id):
-        return self.get_pushers_by({"user_name": user_id})
+    async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
+        return await self.get_pushers_by({"user_name": user_id})
 
-    async def get_pushers_by(self, keyvalues):
+    async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
         ret = await self.db_pool.simple_select_list(
             "pushers",
             keyvalues,
@@ -83,7 +98,7 @@ class PusherWorkerStore(SQLBaseStore):
         )
         return self._decode_pushers_rows(ret)
 
-    async def get_all_pushers(self):
+    async def get_all_pushers(self) -> Iterator[PusherConfig]:
         def get_pushers(txn):
             txn.execute("SELECT * FROM pushers")
             rows = self.db_pool.cursor_to_dict(txn)
@@ -159,14 +174,16 @@ class PusherWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=1, max_entries=15000)
-    async def get_if_user_has_pusher(self, user_id):
+    async def get_if_user_has_pusher(self, user_id: str):
         # This only exists for the cachedList decorator
         raise NotImplementedError()
 
     @cachedList(
         cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
     )
-    async def get_if_users_have_pushers(self, user_ids):
+    async def get_if_users_have_pushers(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, bool]:
         rows = await self.db_pool.simple_select_many_batch(
             table="pushers",
             column="user_name",
@@ -224,7 +241,7 @@ class PusherWorkerStore(SQLBaseStore):
         return bool(updated)
 
     async def update_pusher_failing_since(
-        self, app_id, pushkey, user_id, failing_since
+        self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
     ) -> None:
         await self.db_pool.simple_update(
             table="pushers",
@@ -233,7 +250,9 @@ class PusherWorkerStore(SQLBaseStore):
             desc="update_pusher_failing_since",
         )
 
-    async def get_throttle_params_by_room(self, pusher_id):
+    async def get_throttle_params_by_room(
+        self, pusher_id: str
+    ) -> Dict[str, ThrottleParams]:
         res = await self.db_pool.simple_select_list(
             "pusher_throttle",
             {"pusher": pusher_id},
@@ -243,43 +262,44 @@ class PusherWorkerStore(SQLBaseStore):
 
         params_by_room = {}
         for row in res:
-            params_by_room[row["room_id"]] = {
-                "last_sent_ts": row["last_sent_ts"],
-                "throttle_ms": row["throttle_ms"],
-            }
+            params_by_room[row["room_id"]] = ThrottleParams(
+                row["last_sent_ts"], row["throttle_ms"],
+            )
 
         return params_by_room
 
-    async def set_throttle_params(self, pusher_id, room_id, params) -> None:
+    async def set_throttle_params(
+        self, pusher_id: str, room_id: str, params: ThrottleParams
+    ) -> None:
         # no need to lock because `pusher_throttle` has a primary key on
         # (pusher, room_id) so simple_upsert will retry
         await self.db_pool.simple_upsert(
             "pusher_throttle",
             {"pusher": pusher_id, "room_id": room_id},
-            params,
+            {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
             desc="set_throttle_params",
             lock=False,
         )
 
 
 class PusherStore(PusherWorkerStore):
-    def get_pushers_stream_token(self):
+    def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
     async def add_pusher(
         self,
-        user_id,
-        access_token,
-        kind,
-        app_id,
-        app_display_name,
-        device_display_name,
-        pushkey,
-        pushkey_ts,
-        lang,
-        data,
-        last_stream_ordering,
-        profile_tag="",
+        user_id: str,
+        access_token: Optional[int],
+        kind: str,
+        app_id: str,
+        app_display_name: str,
+        device_display_name: str,
+        pushkey: str,
+        pushkey_ts: int,
+        lang: Optional[str],
+        data: Optional[JsonDict],
+        last_stream_ordering: int,
+        profile_tag: str = "",
     ) -> None:
         async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
@@ -294,7 +314,7 @@ class PusherStore(PusherWorkerStore):
                     "device_display_name": device_display_name,
                     "ts": pushkey_ts,
                     "lang": lang,
-                    "data": bytearray(encode_canonical_json(data)),
+                    "data": json_encoder.encode(data),
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
@@ -311,20 +331,22 @@ class PusherStore(PusherWorkerStore):
                 # invalidate, since we the user might not have had a pusher before
                 await self.db_pool.runInteraction(
                     "add_pusher",
-                    self._invalidate_cache_and_stream,
+                    self._invalidate_cache_and_stream,  # type: ignore
                     self.get_if_user_has_pusher,
                     (user_id,),
                 )
 
     async def delete_pusher_by_app_id_pushkey_user_id(
-        self, app_id, pushkey, user_id
+        self, app_id: str, pushkey: str, user_id: str
     ) -> None:
         def delete_pusher_txn(txn, stream_id):
-            self._invalidate_cache_and_stream(
+            self._invalidate_cache_and_stream(  # type: ignore
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
-            self.db_pool.simple_delete_one_txn(
+            # It is expected that there is exactly one pusher to delete, but
+            # if it isn't there (or there are multiple) delete them all.
+            self.db_pool.simple_delete_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1e7949a323..e4843a202c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,15 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
 import logging
 from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet import defer
 
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
-    """This is an abstract base class where subclasses must implement
-    `get_max_receipt_stream_id` which can be called in the initializer.
-    """
-
+class ReceiptsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_receipts = (
+                self._instance_name in hs.config.worker.writers.receipts
+            )
+
+            self._receipts_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="receipts",
+                instance_name=self._instance_name,
+                tables=[("receipts_linearized", "instance_name", "stream_id")],
+                sequence_name="receipts_sequence",
+                writers=hs.config.worker.writers.receipts,
+            )
+        else:
+            self._can_write_to_receipts = True
+
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                self._receipts_id_gen = StreamIdGenerator(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+            else:
+                self._receipts_id_gen = SlavedIdTracker(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+
         super().__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
-    @abc.abstractmethod
     def get_max_receipt_stream_id(self):
         """Get the current max stream ID for receipts stream
 
         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._receipts_id_gen.get_current_token()
 
     @cached()
     async def get_users_with_read_receipts_in_room(self, room_id):
@@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
-
-class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = StreamIdGenerator(
-            db_conn, "receipts_linearized", "stream_id"
+    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+        self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self.get_last_receipt_event_id_for_user.invalidate(
+            (user_id, room_id, receipt_type)
         )
+        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+        self.get_receipts_for_room.invalidate((room_id, receipt_type))
+
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def insert_linearized_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
         """
+        assert self._can_write_to_receipts
+
         res = self.db_pool.simple_select_one_txn(
             txn,
             table="events",
@@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
                     )
                     return None
 
-        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
-        txn.call_after(
-            self._invalidate_get_users_with_receipts_in_room,
-            room_id,
-            receipt_type,
-            user_id,
-        )
-        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
-        # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(
-            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
+            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
         )
 
         txn.call_after(
             self._receipts_stream_cache.entity_has_changed, room_id, stream_id
         )
 
-        txn.call_after(
-            self.get_last_receipt_event_id_for_user.invalidate,
-            (user_id, room_id, receipt_type),
-        )
-
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_linearized",
@@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
         Automatically does conversion between linearized and graph
         representations.
         """
+        assert self._can_write_to_receipts
+
         if not event_ids:
             return None
 
@@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     async def insert_graph_receipt(
         self, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         return await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
@@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     def insert_graph_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
         txn.call_after(
             self._invalidate_get_users_with_receipts_in_room,
@@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "data": json_encoder.encode(data),
             },
         )
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fedb8a6c26..8405dd460f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
+    async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+        """Sets whether a user shadow-banned.
+
+        Args:
+            user: user ID of the user to test
+            shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+        """
+
+        def set_shadow_banned_txn(txn):
+            self.db_pool.simple_update_one_txn(
+                txn,
+                table="users",
+                keyvalues={"name": user.to_string()},
+                updatevalues={"shadow_banned": shadow_banned},
+            )
+            # In order for this to apply immediately, clear the cache for this user.
+            tokens = self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="access_tokens",
+                keyvalues={"user_id": user.to_string()},
+                retcol="token",
+            )
+            for token in tokens:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (token,)
+                )
+
+        await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
     def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
         sql = """
             SELECT users.name as user_id,
@@ -443,6 +472,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
+    async def record_user_external_id(
+        self, auth_provider: str, external_id: str, user_id: str
+    ) -> None:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
+        await self.db_pool.simple_insert(
+            table="user_external_ids",
+            values={
+                "auth_provider": auth_provider,
+                "external_id": external_id,
+                "user_id": user_id,
+            },
+            desc="record_user_external_id",
+        )
+
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
     ) -> Optional[str]:
@@ -463,6 +512,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_user_by_external_id",
         )
 
+    async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
+        """Look up external ids for the given user
+
+        Args:
+            mxid: the MXID to be looked up
+
+        Returns:
+            Tuples of (auth_provider, external_id)
+        """
+        res = await self.db_pool.simple_select_list(
+            table="user_external_ids",
+            keyvalues={"user_id": mxid},
+            retcols=("auth_provider", "external_id"),
+            desc="get_external_ids_by_user",
+        )
+        return [(r["auth_provider"], r["external_id"]) for r in res]
+
     async def count_all_users(self):
         """Counts all users registered on the homeserver."""
 
@@ -926,6 +992,42 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="del_user_pending_deactivation",
         )
 
+    async def get_access_token_last_validated(self, token_id: int) -> int:
+        """Retrieves the time (in milliseconds) of the last validation of an access token.
+
+        Args:
+            token_id: The ID of the access token to update.
+        Raises:
+            StoreError if the access token was not found.
+
+        Returns:
+            The last validation time.
+        """
+        result = await self.db_pool.simple_select_one_onecol(
+            "access_tokens", {"id": token_id}, "last_validated"
+        )
+
+        # If this token has not been validated (since starting to track this),
+        # return 0 instead of None.
+        return result or 0
+
+    async def update_access_token_last_validated(self, token_id: int) -> None:
+        """Updates the last time an access token was validated.
+
+        Args:
+            token_id: The ID of the access token to update.
+        Raises:
+            StoreError if there was a problem updating this.
+        """
+        now = self._clock.time_msec()
+
+        await self.db_pool.simple_update_one(
+            "access_tokens",
+            {"id": token_id},
+            {"last_validated": now},
+            desc="update_access_token_last_validated",
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
@@ -963,6 +1065,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "user_external_ids_user_id_idx",
+            index_name="user_external_ids_user_id_idx",
+            table="user_external_ids",
+            columns=["user_id"],
+            unique=False,
+        )
+
     async def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
@@ -1043,7 +1153,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 FROM user_threepids
             """
 
-            txn.executemany(sql, [(id_server,) for id_server in id_servers])
+            txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
             await self.db_pool.runInteraction(
@@ -1125,6 +1235,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             The token ID
         """
         next_id = self._access_tokens_id_gen.get_next()
+        now = self._clock.time_msec()
 
         await self.db_pool.simple_insert(
             "access_tokens",
@@ -1135,6 +1246,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 "device_id": device_id,
                 "valid_until_ms": valid_until_ms,
                 "puppets_user_id": puppets_user_id,
+                "last_validated": now,
             },
             desc="add_access_token_to_user",
         )
@@ -1308,26 +1420,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-    async def record_user_external_id(
-        self, auth_provider: str, external_id: str, user_id: str
-    ) -> None:
-        """Record a mapping from an external user id to a mxid
-
-        Args:
-            auth_provider: identifier for the remote auth provider
-            external_id: id on that system
-            user_id: complete mxid that it is mapped to
-        """
-        await self.db_pool.simple_insert(
-            table="user_external_ids",
-            values={
-                "auth_provider": auth_provider,
-                "external_id": external_id,
-                "user_id": user_id,
-            },
-            desc="record_user_external_id",
-        )
-
     async def user_set_password_hash(
         self, user_id: str, password_hash: Optional[str]
     ) -> None:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6b89db15c9..a9fcb5f59c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -16,7 +16,6 @@
 
 import collections
 import logging
-import re
 from abc import abstractmethod
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
@@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
+from synapse.util.stringutils import MXC_REGEX
 
 logger = logging.getLogger(__name__)
 
@@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore):
         return await self.db_pool.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
-            retcols=("room_id", "is_public", "creator"),
+            retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
             desc="get_room",
             allow_none=True,
         )
@@ -379,14 +379,14 @@ class RoomWorkerStore(SQLBaseStore):
         # Filter room names by a string
         where_statement = ""
         if search_term:
-            where_statement = "WHERE state.name LIKE ?"
+            where_statement = "WHERE LOWER(state.name) LIKE ?"
 
             # Our postgres db driver converts ? -> %s in SQL strings as that's the
             # placeholder for postgres.
             # HOWEVER, if you put a % into your SQL then everything goes wibbly.
             # To get around this, we're going to surround search_term with %'s
             # before giving it to the database in python instead
-            search_term = "%" + search_term + "%"
+            search_term = "%" + search_term.lower() + "%"
 
         # Set ordering
         if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore):
             The local and remote media as a lists of tuples where the key is
             the hostname and the value is the media ID.
         """
-        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
-
         sql = """
             SELECT stream_ordering, json FROM events
             JOIN event_json USING (room_id, event_id)
@@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore):
                 for url in (content_url, thumbnail_url):
                     if not url:
                         continue
-                    matches = mxc_re.match(url)
+                    matches = MXC_REGEX.match(url)
                     if matches:
                         hostname = matches.group(1)
                         media_id = matches.group(2)
@@ -1166,6 +1164,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         # It's overridden by RoomStore for the synapse master.
         raise NotImplementedError()
 
+    async def has_auth_chain_index(self, room_id: str) -> bool:
+        """Check if the room has (or can have) a chain cover index.
+
+        Defaults to True if we don't have an entry in `rooms` table nor any
+        events for the room.
+        """
+
+        has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcol="has_auth_chain_index",
+            desc="has_auth_chain_index",
+            allow_none=True,
+        )
+
+        if has_auth_chain_index:
+            return True
+
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        max_ordering = await self.db_pool.simple_select_one_onecol(
+            table="events",
+            keyvalues={"room_id": room_id},
+            retcol="MAX(stream_ordering)",
+            allow_none=True,
+            desc="upsert_room_on_join",
+        )
+
+        return max_ordering is None
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1179,12 +1208,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         Called when we join a room over federation, and overwrites any room version
         currently in the table.
         """
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
         await self.db_pool.simple_upsert(
             desc="upsert_room_on_join",
             table="rooms",
             keyvalues={"room_id": room_id},
             values={"room_version": room_version.identifier},
-            insertion_values={"is_public": False, "creator": ""},
+            insertion_values={
+                "is_public": False,
+                "creator": "",
+                "has_auth_chain_index": has_auth_chain_index,
+            },
             # rooms has a unique constraint on room_id, so no need to lock when doing an
             # emulated upsert.
             lock=False,
@@ -1219,6 +1257,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         "creator": room_creator_user_id,
                         "is_public": is_public,
                         "room_version": room_version.identifier,
+                        "has_auth_chain_index": True,
                     },
                 )
                 if is_public:
@@ -1247,6 +1286,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         When we receive an invite or any other event over federation that may relate to a room
         we are not in, store the version of the room if we don't already know the room version.
         """
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
         await self.db_pool.simple_upsert(
             desc="maybe_store_room_on_outlier_membership",
             table="rooms",
@@ -1256,6 +1300,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 "room_version": room_version.identifier,
                 "is_public": False,
                 "creator": "",
+                "has_auth_chain_index": has_auth_chain_index,
             },
             # rooms has a unique constraint on room_id, so no need to lock when doing an
             # emulated upsert.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..92382bed28 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             "max_stream_id_exclusive", self._stream_order_on_start + 1
         )
 
-        INSERT_CLUMP_SIZE = 1000
-
         def add_membership_profile_txn(txn):
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
                 UPDATE room_memberships SET display_name = ?, avatar_url = ?
                 WHERE event_id = ? AND room_id = ?
             """
-            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(to_update_sql, clump)
+            txn.execute_batch(to_update_sql, to_update)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
new file mode 100644
index 0000000000..8f5e65aa71
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5825, 'user_external_ids_user_id_idx', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
new file mode 100644
index 0000000000..1a101cd5eb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The last time this access token was "validated" (i.e. logged in or succeeded
+-- at user-interactive authentication).
+ALTER TABLE access_tokens ADD COLUMN last_validated BIGINT;
diff --git a/synapse/storage/databases/main/schema/delta/58/27local_invites.sql b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
new file mode 100644
index 0000000000..44b2a0572f
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/27local_invites.sql
@@ -0,0 +1,18 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is unused since Synapse v1.17.0.
+DROP TABLE local_invites;
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
new file mode 100644
index 0000000000..de57645019
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE access_tokens DROP COLUMN last_used;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
new file mode 100644
index 0000000000..ee0e3521bf
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28drop_last_used_column.sql.sqlite
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ -- Dropping last_used column from access_tokens table.
+
+CREATE TABLE access_tokens2 (
+    id BIGINT PRIMARY KEY, 
+    user_id TEXT NOT NULL, 
+    device_id TEXT, 
+    token TEXT NOT NULL,
+    valid_until_ms BIGINT,
+    puppets_user_id TEXT,
+    last_validated BIGINT,
+    UNIQUE(token) 
+);
+
+INSERT INTO access_tokens2(id, user_id, device_id, token)
+    SELECT id, user_id, device_id, token FROM access_tokens;
+
+DROP TABLE access_tokens;
+ALTER TABLE access_tokens2 RENAME TO access_tokens;
+
+CREATE INDEX access_tokens_device_id ON access_tokens (user_id, device_id);
+
+
+-- Re-adding foreign key reference in event_txn_id table
+
+CREATE TABLE event_txn_id2 (
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    token_id BIGINT NOT NULL,
+    txn_id TEXT NOT NULL,
+    inserted_ts BIGINT NOT NULL,
+    FOREIGN KEY (event_id)
+        REFERENCES events (event_id) ON DELETE CASCADE,
+    FOREIGN KEY (token_id)
+        REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+INSERT INTO event_txn_id2(event_id, room_id, user_id, token_id, txn_id, inserted_ts)
+    SELECT event_id, room_id, user_id, token_id, txn_id, inserted_ts FROM event_txn_id;
+
+DROP TABLE event_txn_id;
+ALTER TABLE event_txn_id2 RENAME TO event_txn_id;
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql
new file mode 100644
index 0000000000..9c95646281
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/28rejected_events_metadata.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5828, 'rejected_events_metadata', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
new file mode 100644
index 0000000000..9e8f35c1d2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -0,0 +1,82 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration denormalises the account_data table into an ignored users table.
+"""
+
+import logging
+from io import StringIO
+
+from synapse.storage._base import db_to_json
+from synapse.storage.engines import BaseDatabaseEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
+from synapse.storage.types import Cursor
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    pass
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    logger.info("Creating ignored_users table")
+    execute_statements_from_stream(cur, StringIO(_create_commands))
+
+    # We now upgrade existing data, if any. We don't do this in `run_upgrade` as
+    # we a) want to run these before adding constraints and b) `run_upgrade` is
+    # not run on empty databases.
+    insert_sql = """
+    INSERT INTO ignored_users (ignorer_user_id, ignored_user_id) VALUES (?, ?)
+    """
+
+    logger.info("Converting existing ignore lists")
+    cur.execute(
+        "SELECT user_id, content FROM account_data WHERE account_data_type = 'm.ignored_user_list'"
+    )
+    for user_id, content_json in cur.fetchall():
+        content = db_to_json(content_json)
+
+        # The content should be the form of a dictionary with a key
+        # "ignored_users" pointing to a dictionary with keys of ignored users.
+        #
+        # { "ignored_users": "@someone:example.org": {} }
+        ignored_users = content.get("ignored_users", {})
+        if isinstance(ignored_users, dict) and ignored_users:
+            cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
+
+    # Add indexes after inserting data for efficiency.
+    logger.info("Adding constraints to ignored_users table")
+    execute_statements_from_stream(cur, StringIO(_constraints_commands))
+
+
+# there might be duplicates, so the easiest way to achieve this is to create a new
+# table with the right data, and renaming it into place
+
+_create_commands = """
+-- Users which are ignored when calculating push notifications. This data is
+-- denormalized from account data.
+CREATE TABLE IF NOT EXISTS ignored_users(
+    ignorer_user_id TEXT NOT NULL,  -- The user ID of the user who is ignoring another user. (This is a local user.)
+    ignored_user_id TEXT NOT NULL  -- The user ID of the user who is being ignored. (This is a local or remote user.)
+);
+"""
+
+_constraints_commands = """
+CREATE UNIQUE INDEX ignored_users_uniqueness ON ignored_users (ignorer_user_id, ignored_user_id);
+
+-- Add an index on ignored_users since look-ups are done to get all ignorers of an ignored user.
+CREATE INDEX ignored_users_ignored_user_id ON ignored_users (ignored_user_id);
+"""
diff --git a/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
new file mode 100644
index 0000000000..d781a92fec
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE device_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT;
+ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
new file mode 100644
index 0000000000..45a845a3a5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence;
+
+-- We need to take the max across both device_inbox and device_federation_outbox
+-- tables as they share the ID generator
+SELECT setval('device_inbox_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox)
+    )
+));
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
new file mode 100644
index 0000000000..729196cfd5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
@@ -0,0 +1,52 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- See docs/auth_chain_difference_algorithm.md
+
+CREATE TABLE event_auth_chains (
+  event_id TEXT PRIMARY KEY,
+  chain_id BIGINT NOT NULL,
+  sequence_number BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number);
+
+
+CREATE TABLE event_auth_chain_links (
+  origin_chain_id BIGINT NOT NULL,
+  origin_sequence_number BIGINT NOT NULL,
+
+  target_chain_id BIGINT NOT NULL,
+  target_sequence_number BIGINT NOT NULL
+);
+
+
+CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id);
+
+
+-- Events that we have persisted but not calculated auth chains for,
+-- e.g. out of band memberships (where we don't have the auth chain)
+CREATE TABLE event_auth_chain_to_calculate (
+  event_id TEXT PRIMARY KEY,
+  room_id TEXT NOT NULL,
+  type TEXT NOT NULL,
+  state_key TEXT NOT NULL
+);
+
+CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id);
+
+
+-- Whether we've calculated the above index for a room.
+ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
new file mode 100644
index 0000000000..e8a035bbeb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
new file mode 100644
index 0000000000..64ab696cfe
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04drop_account_data.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS account_data_max_stream_id;
diff --git a/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
new file mode 100644
index 0000000000..fb71b360a0
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/05cache_invalidation.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is no longer used and was only kept until we bumped the schema version.
+DROP TABLE IF EXISTS cache_invalidation_stream;
diff --git a/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
new file mode 100644
index 0000000000..fe3dca71dd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (5906, 'chain_cover', '{}', 'rejected_events_metadata');
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
new file mode 100644
index 0000000000..46abf8d562
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql
@@ -0,0 +1,20 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
+ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
+ALTER TABLE account_data ADD COLUMN instance_name TEXT;
+
+ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
new file mode 100644
index 0000000000..4a6e6c74f5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres
@@ -0,0 +1,32 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
+
+-- We need to take the max across all the account_data tables as they share the
+-- ID generator
+SELECT setval('account_data_sequence', (
+    SELECT GREATEST(
+        (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
+        (SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
+    )
+));
+
+CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
+
+SELECT setval('receipts_sequence', (
+    SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
+));
diff --git a/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
new file mode 100644
index 0000000000..9f2b5ebc5a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We incorrectly populated these, so we delete them and let the
+-- MultiWriterIdGenerator repopulate it.
+DELETE FROM stream_positions WHERE stream_name = 'receipts' OR stream_name = 'account_data';
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
@@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.executemany(sql, args)
+            txn.execute_batch(sql, args)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
     async def search_rooms(
         self,
-        room_ids: List[str],
+        room_ids: Collection[str],
         search_term: str,
         keys: List[str],
         limit,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0cdb3ec1f7..d421d18f8d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 
 import logging
-from collections import Counter
 from enum import Enum
 from itertools import chain
 from typing import Any, Dict, List, Optional, Tuple
 
+from typing_extensions import Counter
+
 from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
@@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore):
         return slice_list
 
     @cached()
-    async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
+    async def get_earliest_token_for_stats(
+        self, stats_type: str, id: str
+    ) -> Optional[int]:
         """
         Fetch the "earliest token". This is used by the room stats delta
         processor to ignore deltas that have been processed between the
@@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore):
         )
 
     async def bulk_update_stats_delta(
-        self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
+        self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
     ) -> None:
         """Bulk update stats tables for a given stream_id and updates the stats
         incremental position.
@@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore):
 
     async def get_changes_room_total_events_and_bytes(
         self, min_pos: int, max_pos: int
-    ) -> Dict[str, Dict[str, int]]:
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Fetches the counts of events in the given range of stream IDs.
 
         Args:
@@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore):
             max_pos,
         )
 
-    def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+    def get_changes_room_total_events_and_bytes_txn(
+        self, txn, low_pos: int, high_pos: int
+    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
         """Gets the total_events and total_event_bytes counts for rooms and
         senders, in a range of stream_orderings (including backfilled events).
 
         Args:
             txn
-            low_pos (int): Low stream ordering
-            high_pos (int): High stream ordering
+            low_pos: Low stream ordering
+            high_pos: High stream ordering
 
         Returns:
-            tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
-            room and user deltas for total_events/total_event_bytes in the
+            The room and user deltas for total_events/total_event_bytes in the
             format of `stats_id` -> fields
         """
 
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 9f120d3cb6..50067eabfc 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
         )
         return {row["tag"]: db_to_json(row["content"]) for row in rows}
 
-
-class TagsStore(TagsWorkerStore):
     async def add_tag_to_room(
         self, user_id: str, room_id: str, tag: str, content: JsonDict
     ) -> int:
@@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
+        assert self._can_write_to_account_data
+
         content_json = json_encoder.encode(content)
 
         def add_tag_txn(txn, next_id):
@@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
+        assert self._can_write_to_account_data
 
         def remove_tag_txn(txn, next_id):
             sql = (
@@ -250,21 +251,12 @@ class TagsStore(TagsWorkerStore):
             room_id: The ID of the room.
             next_id: The the revision to advance to.
         """
+        assert self._can_write_to_account_data
 
         txn.call_after(
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
         )
 
-        # Note: This is only here for backwards compat to allow admins to
-        # roll back to a previous Synapse version. Next time we update the
-        # database version we can remove this table.
-        update_max_id_sql = (
-            "UPDATE account_data_max_stream_id"
-            " SET stream_id = ?"
-            " WHERE stream_id < ?"
-        )
-        txn.execute(update_max_id_sql, (next_id, next_id))
-
         update_sql = (
             "UPDATE room_tags_revisions"
             " SET stream_id = ?"
@@ -288,3 +280,7 @@ class TagsStore(TagsWorkerStore):
                 # which stream_id ends up in the table, as long as it is higher
                 # than the id that the client has.
                 pass
+
+
+class TagsStore(TagsWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 59207cadd4..cea595ff19 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -464,19 +464,17 @@ class TransactionStore(TransactionWorkerStore):
         txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
     ) -> List[str]:
         q = """
-            SELECT destination FROM destinations
-                WHERE destination IN (
-                    SELECT destination FROM destination_rooms
-                        WHERE destination_rooms.stream_ordering >
-                            destinations.last_successful_stream_ordering
-                )
-                AND destination > ?
-                AND (
-                    retry_last_ts IS NULL OR
-                    retry_last_ts + retry_interval < ?
-                )
-                ORDER BY destination
-                LIMIT 25
+            SELECT DISTINCT destination FROM destinations
+            INNER JOIN destination_rooms USING (destination)
+                WHERE
+                    stream_ordering > last_successful_stream_ordering
+                    AND destination > ?
+                    AND (
+                        retry_last_ts IS NULL OR
+                        retry_last_ts + retry_interval < ?
+                    )
+                    ORDER BY destination
+                    LIMIT 25
         """
         txn.execute(
             q,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index d87ceec6da..7b9729da09 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -17,7 +17,7 @@ import logging
 import re
 from typing import Any, Dict, Iterable, Optional, Set, Tuple
 
-from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
@@ -360,7 +360,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         if hist_vis_id:
             hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
             if hist_vis_ev:
-                if hist_vis_ev.content.get("history_visibility") == "world_readable":
+                if (
+                    hist_vis_ev.content.get("history_visibility")
+                    == HistoryVisibility.WORLD_READABLE
+                ):
                     return True
 
         return False
@@ -393,9 +396,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                     sql = """
                         INSERT INTO user_directory_search(user_id, vector)
                         VALUES (?,
-                            setweight(to_tsvector('english', ?), 'A')
-                            || setweight(to_tsvector('english', ?), 'D')
-                            || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                            setweight(to_tsvector('simple', ?), 'A')
+                            || setweight(to_tsvector('simple', ?), 'D')
+                            || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
                         ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
                     """
                     txn.execute(
@@ -415,9 +418,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                         sql = """
                             INSERT INTO user_directory_search(user_id, vector)
                             VALUES (?,
-                                setweight(to_tsvector('english', ?), 'A')
-                                || setweight(to_tsvector('english', ?), 'D')
-                                || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                                setweight(to_tsvector('simple', ?), 'A')
+                                || setweight(to_tsvector('simple', ?), 'D')
+                                || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
                             )
                         """
                         txn.execute(
@@ -432,9 +435,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                     elif new_entry is False:
                         sql = """
                             UPDATE user_directory_search
-                            SET vector = setweight(to_tsvector('english', ?), 'A')
-                                || setweight(to_tsvector('english', ?), 'D')
-                                || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                            SET vector = setweight(to_tsvector('simple', ?), 'A')
+                                || setweight(to_tsvector('simple', ?), 'D')
+                                || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
                             WHERE user_id = ?
                         """
                         txn.execute(
@@ -537,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             desc="get_user_in_directory",
         )
 
-    async def update_user_directory_stream_pos(self, stream_id: str) -> None:
+    async def update_user_directory_stream_pos(self, stream_id: int) -> None:
         await self.db_pool.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
@@ -761,7 +764,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
                 INNER JOIN user_directory AS d USING (user_id)
                 WHERE
                     %s
-                    AND vector @@ to_tsquery('english', ?)
+                    AND vector @@ to_tsquery('simple', ?)
                 ORDER BY
                     (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
                     * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
@@ -770,13 +773,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
                         3 * ts_rank_cd(
                             '{0.1, 0.1, 0.9, 1.0}',
                             vector,
-                            to_tsquery('english', ?),
+                            to_tsquery('simple', ?),
                             8
                         )
                         + ts_rank_cd(
                             '{0.1, 0.1, 0.9, 1.0}',
                             vector,
-                            to_tsquery('english', ?),
+                            to_tsquery('simple', ?),
                             8
                         )
                     )
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..89cdc84a9c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             )
 
         logger.info("[purge] removing redundant state groups")
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
-        txn.executemany(
+        txn.execute_batch(
             "DELETE FROM state_groups WHERE id = ?",
             ((sg,) for sg in state_groups_to_delete),
         )
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index afd10f7bae..c03871f393 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -17,11 +17,12 @@
 import logging
 
 import attr
+from signedjson.types import VerifyKey
 
 logger = logging.getLogger(__name__)
 
 
 @attr.s(slots=True, frozen=True)
 class FetchKeyResult:
-    verify_key = attr.ib()  # VerifyKey: the key itself
-    valid_until_ts = attr.ib()  # int: how long we can use this key for
+    verify_key = attr.ib(type=VerifyKey)  # the key itself
+    valid_until_ts = attr.ib(type=int)  # how long we can use this key for
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 70e636b0ba..61fc49c69c 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,14 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+    Collection,
+    PersistedEventPosition,
+    RoomStreamToken,
+    StateMap,
+    get_domain_from_id,
+)
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -68,6 +75,21 @@ stale_forward_extremities_counter = Histogram(
     buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
 )
 
+state_resolutions_during_persistence = Counter(
+    "synapse_storage_events_state_resolutions_during_persistence",
+    "Number of times we had to do state res to calculate new current state",
+)
+
+potential_times_prune_extremities = Counter(
+    "synapse_storage_events_potential_times_prune_extremities",
+    "Number of times we might be able to prune extremities",
+)
+
+times_pruned_extremities = Counter(
+    "synapse_storage_events_times_pruned_extremities",
+    "Number of times we were actually be able to prune extremities",
+)
+
 
 class _EventPeristenceQueue:
     """Queues up events so that they can be persisted in bulk with only one
@@ -454,7 +476,15 @@ class EventsPersistenceStorage:
                                 latest_event_ids,
                                 new_latest_event_ids,
                             )
-                            current_state, delta_ids = res
+                            current_state, delta_ids, new_latest_event_ids = res
+
+                            # there should always be at least one forward extremity.
+                            # (except during the initial persistence of the send_join
+                            # results, in which case there will be no existing
+                            # extremities, so we'll `continue` above and skip this bit.)
+                            assert new_latest_event_ids, "No forward extremities left!"
+
+                            new_forward_extremeties[room_id] = new_latest_event_ids
 
                         # If either are not None then there has been a change,
                         # and we need to work out the delta (or use that
@@ -573,29 +603,35 @@ class EventsPersistenceStorage:
         self,
         room_id: str,
         events_context: List[Tuple[EventBase, EventContext]],
-        old_latest_event_ids: Iterable[str],
-        new_latest_event_ids: Iterable[str],
-    ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
+        old_latest_event_ids: Set[str],
+        new_latest_event_ids: Set[str],
+    ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
         """Calculate the current state dict after adding some new events to
         a room
 
         Args:
-            room_id (str):
+            room_id:
                 room to which the events are being added. Used for logging etc
 
-            events_context (list[(EventBase, EventContext)]):
+            events_context:
                 events and contexts which are being added to the room
 
-            old_latest_event_ids (iterable[str]):
+            old_latest_event_ids:
                 the old forward extremities for the room.
 
-            new_latest_event_ids (iterable[str]):
+            new_latest_event_ids :
                 the new forward extremities for the room.
 
         Returns:
-            Returns a tuple of two state maps, the first being the full new current
-            state and the second being the delta to the existing current state.
-            If both are None then there has been no change.
+            Returns a tuple of two state maps and a set of new forward
+            extremities.
+
+            The first state map is the full new current state and the second
+            is the delta to the existing current state. If both are None then
+            there has been no change.
+
+            The function may prune some old entries from the set of new
+            forward extremities if it's safe to do so.
 
             If there has been a change then we only return the delta if its
             already been calculated. Conversely if we do know the delta then
@@ -672,7 +708,7 @@ class EventsPersistenceStorage:
         # If they old and new groups are the same then we don't need to do
         # anything.
         if old_state_groups == new_state_groups:
-            return None, None
+            return None, None, new_latest_event_ids
 
         if len(new_state_groups) == 1 and len(old_state_groups) == 1:
             # If we're going from one state group to another, lets check if
@@ -689,7 +725,7 @@ class EventsPersistenceStorage:
                 # the current state in memory then lets also return that,
                 # but it doesn't matter if we don't.
                 new_state = state_groups_map.get(new_state_group)
-                return new_state, delta_ids
+                return new_state, delta_ids, new_latest_event_ids
 
         # Now that we have calculated new_state_groups we need to get
         # their state IDs so we can resolve to a single state set.
@@ -701,7 +737,7 @@ class EventsPersistenceStorage:
         if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
             # state is.
-            return state_groups_map[new_state_groups.pop()], None
+            return state_groups_map[new_state_groups.pop()], None, new_latest_event_ids
 
         # Ok, we need to defer to the state handler to resolve our state sets.
 
@@ -734,7 +770,139 @@ class EventsPersistenceStorage:
             state_res_store=StateResolutionStore(self.main_store),
         )
 
-        return res.state, None
+        state_resolutions_during_persistence.inc()
+
+        # If the returned state matches the state group of one of the new
+        # forward extremities then we check if we are able to prune some state
+        # extremities.
+        if res.state_group and res.state_group in new_state_groups:
+            new_latest_event_ids = await self._prune_extremities(
+                room_id,
+                new_latest_event_ids,
+                res.state_group,
+                event_id_to_state_group,
+                events_context,
+            )
+
+        return res.state, None, new_latest_event_ids
+
+    async def _prune_extremities(
+        self,
+        room_id: str,
+        new_latest_event_ids: Set[str],
+        resolved_state_group: int,
+        event_id_to_state_group: Dict[str, int],
+        events_context: List[Tuple[EventBase, EventContext]],
+    ) -> Set[str]:
+        """See if we can prune any of the extremities after calculating the
+        resolved state.
+        """
+        potential_times_prune_extremities.inc()
+
+        # We keep all the extremities that have the same state group, and
+        # see if we can drop the others.
+        new_new_extrems = {
+            e
+            for e in new_latest_event_ids
+            if event_id_to_state_group[e] == resolved_state_group
+        }
+
+        dropped_extrems = set(new_latest_event_ids) - new_new_extrems
+
+        logger.debug("Might drop extremities: %s", dropped_extrems)
+
+        # We only drop events from the extremities list if:
+        #   1. we're not currently persisting them;
+        #   2. they're not our own events (or are dummy events); and
+        #   3. they're either:
+        #       1. over N hours old and more than N events ago (we use depth to
+        #          calculate); or
+        #       2. we are persisting an event from the same domain and more than
+        #          M events ago.
+        #
+        # The idea is that we don't want to drop events that are "legitimate"
+        # extremities (that we would want to include as prev events), only
+        # "stuck" extremities that are e.g. due to a gap in the graph.
+        #
+        # Note that we either drop all of them or none of them. If we only drop
+        # some of the events we don't know if state res would come to the same
+        # conclusion.
+
+        for ev, _ in events_context:
+            if ev.event_id in dropped_extrems:
+                logger.debug(
+                    "Not dropping extremities: %s is being persisted", ev.event_id
+                )
+                return new_latest_event_ids
+
+        dropped_events = await self.main_store.get_events(
+            dropped_extrems,
+            allow_rejected=True,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
+        )
+
+        new_senders = {get_domain_from_id(e.sender) for e, _ in events_context}
+
+        one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+        current_depth = max(e.depth for e, _ in events_context)
+        for event in dropped_events.values():
+            # If the event is a local dummy event then we should check it
+            # doesn't reference any local events, as we want to reference those
+            # if we send any new events.
+            #
+            # Note we do this recursively to handle the case where a dummy event
+            # references a dummy event that only references remote events.
+            #
+            # Ideally we'd figure out a way of still being able to drop old
+            # dummy events that reference local events, but this is good enough
+            # as a first cut.
+            events_to_check = [event]
+            while events_to_check:
+                new_events = set()
+                for event_to_check in events_to_check:
+                    if self.is_mine_id(event_to_check.sender):
+                        if event_to_check.type != EventTypes.Dummy:
+                            logger.debug("Not dropping own event")
+                            return new_latest_event_ids
+                        new_events.update(event_to_check.prev_event_ids())
+
+                prev_events = await self.main_store.get_events(
+                    new_events,
+                    allow_rejected=True,
+                    redact_behaviour=EventRedactBehaviour.AS_IS,
+                )
+                events_to_check = prev_events.values()
+
+            if (
+                event.origin_server_ts < one_day_ago
+                and event.depth < current_depth - 100
+            ):
+                continue
+
+            # We can be less conservative about dropping extremities from the
+            # same domain, though we do want to wait a little bit (otherwise
+            # we'll immediately remove all extremities from a given server).
+            if (
+                get_domain_from_id(event.sender) in new_senders
+                and event.depth < current_depth - 20
+            ):
+                continue
+
+            logger.debug(
+                "Not dropping as too new and not in new_senders: %s", new_senders,
+            )
+
+            return new_latest_event_ids
+
+        times_pruned_extremities.inc()
+
+        logger.info(
+            "Pruning forward extremities in room %s: from %s -> %s",
+            room_id,
+            new_latest_event_ids,
+            new_new_extrems,
+        )
+        return new_new_extrems
 
     async def _calculate_state_delta(
         self, room_id: str, current_state: StateMap[str]
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 459754feab..566ea19bae 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,9 +18,10 @@ import logging
 import os
 import re
 from collections import Counter
-from typing import Optional, TextIO
+from typing import Generator, Iterable, List, Optional, TextIO, Tuple
 
 import attr
+from typing_extensions import Counter as CounterType
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import LoggingDatabaseConnection
@@ -34,10 +35,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-# XXX: If you're about to bump this to 59 (or higher) please create an update
-# that drops the unused `cache_invalidation_stream` table, as per #7436!
-# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
-SCHEMA_VERSION = 58
+SCHEMA_VERSION = 59
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -70,7 +68,7 @@ def prepare_database(
     db_conn: LoggingDatabaseConnection,
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
-    databases: Collection[str] = ["main", "state"],
+    databases: Collection[str] = ("main", "state"),
 ):
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
@@ -155,7 +153,9 @@ def prepare_database(
         raise
 
 
-def _setup_new_database(cur, database_engine, databases):
+def _setup_new_database(
+    cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
+) -> None:
     """Sets up the physical database by finding a base set of "full schemas" and
     then applying any necessary deltas, including schemas from the given data
     stores.
@@ -188,10 +188,9 @@ def _setup_new_database(cur, database_engine, databases):
     folder as well those in the data stores specified.
 
     Args:
-        cur (Cursor): a database cursor
-        database_engine (DatabaseEngine)
-        databases (list[str]): The names of the databases to instantiate
-            on the given physical database.
+        cur: a database cursor
+        database_engine
+        databases: The names of the databases to instantiate on the given physical database.
     """
 
     # We're about to set up a brand new database so we check that its
@@ -199,12 +198,11 @@ def _setup_new_database(cur, database_engine, databases):
     database_engine.check_new_database(cur)
 
     current_dir = os.path.join(dir_path, "schema", "full_schemas")
-    directory_entries = os.listdir(current_dir)
 
     # First we find the highest full schema version we have
     valid_versions = []
 
-    for filename in directory_entries:
+    for filename in os.listdir(current_dir):
         try:
             ver = int(filename)
         except ValueError:
@@ -237,7 +235,7 @@ def _setup_new_database(cur, database_engine, databases):
         for database in databases
     )
 
-    directory_entries = []
+    directory_entries = []  # type: List[_DirectoryListing]
     for directory in directories:
         directory_entries.extend(
             _DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -275,15 +273,15 @@ def _setup_new_database(cur, database_engine, databases):
 
 
 def _upgrade_existing_database(
-    cur,
-    current_version,
-    applied_delta_files,
-    upgraded,
-    database_engine,
-    config,
-    databases,
-    is_empty=False,
-):
+    cur: Cursor,
+    current_version: int,
+    applied_delta_files: List[str],
+    upgraded: bool,
+    database_engine: BaseDatabaseEngine,
+    config: Optional[HomeServerConfig],
+    databases: Collection[str],
+    is_empty: bool = False,
+) -> None:
     """Upgrades an existing physical database.
 
     Delta files can either be SQL stored in *.sql files, or python modules
@@ -323,21 +321,20 @@ def _upgrade_existing_database(
     for a version before applying those in the next version.
 
     Args:
-        cur (Cursor)
-        current_version (int): The current version of the schema.
-        applied_delta_files (list): A list of deltas that have already been
-            applied.
-        upgraded (bool): Whether the current version was generated by having
+        cur
+        current_version: The current version of the schema.
+        applied_delta_files: A list of deltas that have already been applied.
+        upgraded: Whether the current version was generated by having
             applied deltas or from full schema file. If `True` the function
             will never apply delta files for the given `current_version`, since
             the current_version wasn't generated by applying those delta files.
-        database_engine (DatabaseEngine)
-        config (synapse.config.homeserver.HomeServerConfig|None):
+        database_engine
+        config:
             None if we are initialising a blank database, otherwise the application
             config
-        databases (list[str]): The names of the databases to instantiate
+        databases: The names of the databases to instantiate
             on the given physical database.
-        is_empty (bool): Is this a blank database? I.e. do we need to run the
+        is_empty: Is this a blank database? I.e. do we need to run the
             upgrade portions of the delta scripts.
     """
     if is_empty:
@@ -358,6 +355,7 @@ def _upgrade_existing_database(
     if not is_empty and "main" in databases:
         from synapse.storage.databases.main import check_database_before_upgrade
 
+        assert config is not None
         check_database_before_upgrade(cur, database_engine, config)
 
     start_ver = current_version
@@ -374,7 +372,16 @@ def _upgrade_existing_database(
     specific_engine_extensions = (".sqlite", ".postgres")
 
     for v in range(start_ver, SCHEMA_VERSION + 1):
-        logger.info("Applying schema deltas for v%d", v)
+        if not is_worker:
+            logger.info("Applying schema deltas for v%d", v)
+
+            cur.execute("DELETE FROM schema_version")
+            cur.execute(
+                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
+                (v, True),
+            )
+        else:
+            logger.info("Checking schema deltas for v%d", v)
 
         # We need to search both the global and per data store schema
         # directories for schema updates.
@@ -388,10 +395,10 @@ def _upgrade_existing_database(
             )
 
         # Used to check if we have any duplicate file names
-        file_name_counter = Counter()
+        file_name_counter = Counter()  # type: CounterType[str]
 
         # Now find which directories have anything of interest.
-        directory_entries = []
+        directory_entries = []  # type: List[_DirectoryListing]
         for directory in directories:
             logger.debug("Looking for schema deltas in %s", directory)
             try:
@@ -445,11 +452,11 @@ def _upgrade_existing_database(
 
                 module_name = "synapse.storage.v%d_%s" % (v, root_name)
                 with open(absolute_path) as python_file:
-                    module = imp.load_source(module_name, absolute_path, python_file)
+                    module = imp.load_source(module_name, absolute_path, python_file)  # type: ignore
                 logger.info("Running script %s", relative_path)
-                module.run_create(cur, database_engine)
+                module.run_create(cur, database_engine)  # type: ignore
                 if not is_empty:
-                    module.run_upgrade(cur, database_engine, config=config)
+                    module.run_upgrade(cur, database_engine, config=config)  # type: ignore
             elif ext == ".pyc" or file_name == "__pycache__":
                 # Sometimes .pyc files turn up anyway even though we've
                 # disabled their generation; e.g. from distribution package
@@ -488,23 +495,18 @@ def _upgrade_existing_database(
                 (v, relative_path),
             )
 
-            cur.execute("DELETE FROM schema_version")
-            cur.execute(
-                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
-                (v, True),
-            )
-
     logger.info("Schema now up to date")
 
 
-def _apply_module_schemas(txn, database_engine, config):
+def _apply_module_schemas(
+    txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
     """Apply the module schemas for the dynamic modules, if any
 
     Args:
         cur: database cursor
-        database_engine: synapse database engine class
-        config (synapse.config.homeserver.HomeServerConfig):
-            application config
+        database_engine:
+        config: application config
     """
     for (mod, _config) in config.password_providers:
         if not hasattr(mod, "get_db_schema_files"):
@@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
         )
 
 
-def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+def _apply_module_schema_files(
+    cur: Cursor,
+    database_engine: BaseDatabaseEngine,
+    modname: str,
+    names_and_streams: Iterable[Tuple[str, TextIO]],
+) -> None:
     """Apply the module schemas for a single module
 
     Args:
         cur: database cursor
         database_engine: synapse database engine class
-        modname (str): fully qualified name of the module
-        names_and_streams (Iterable[(str, file)]): the names and streams of
-            schemas to be applied
+        modname: fully qualified name of the module
+        names_and_streams: the names and streams of schemas to be applied
     """
     cur.execute(
         "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
@@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
         )
 
 
-def get_statements(f):
+def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment
 
@@ -594,17 +600,19 @@ def get_statements(f):
         statement_buffer = statements[-1].strip()
 
 
-def executescript(txn, schema_path):
+def executescript(txn: Cursor, schema_path: str) -> None:
     with open(schema_path, "r") as f:
         execute_statements_from_stream(txn, f)
 
 
-def execute_statements_from_stream(cur: Cursor, f: TextIO):
+def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
     for statement in get_statements(f):
         cur.execute(statement)
 
 
-def _get_or_create_schema_state(txn, database_engine):
+def _get_or_create_schema_state(
+    txn: Cursor, database_engine: BaseDatabaseEngine
+) -> Optional[Tuple[int, List[str], bool]]:
     # Bluntly try creating the schema_version tables.
     schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
     executescript(txn, schema_path)
@@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
     txn.execute("SELECT version, upgraded FROM schema_version")
     row = txn.fetchone()
     current_version = int(row[0]) if row else None
-    upgraded = bool(row[1]) if row else None
 
     if current_version:
         txn.execute(
@@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
             (current_version,),
         )
         applied_deltas = [d for d, in txn]
+        upgraded = bool(row[1])
         return current_version, applied_deltas, upgraded
 
     return None
@@ -634,5 +642,5 @@ class _DirectoryListing:
     `file_name` attr is kept first.
     """
 
-    file_name = attr.ib()
-    absolute_path = attr.ib()
+    file_name = attr.ib(type=str)
+    absolute_path = attr.ib(type=str)
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index bfa0a9fd06..6c359c1aae 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,7 +15,12 @@
 
 import itertools
 import logging
-from typing import Set
+from typing import TYPE_CHECKING, Set
+
+from synapse.storage.databases import Databases
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -24,10 +29,10 @@ class PurgeEventsStorage:
     """High level interface for purging rooms and event history.
     """
 
-    def __init__(self, hs, stores):
+    def __init__(self, hs: "HomeServer", stores: Databases):
         self.stores = stores
 
-    async def purge_room(self, room_id: str):
+    async def purge_room(self, room_id: str) -> None:
         """Deletes all record of a room
         """
 
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index cec96ad6a7..2564f34b47 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -14,10 +14,12 @@
 # limitations under the License.
 
 import logging
+from typing import Any, Dict, List, Optional, Tuple
 
 import attr
 
 from synapse.api.errors import SynapseError
+from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
@@ -27,18 +29,18 @@ class PaginationChunk:
     """Returned by relation pagination APIs.
 
     Attributes:
-        chunk (list): The rows returned by pagination
-        next_batch (Any|None): Token to fetch next set of results with, if
+        chunk: The rows returned by pagination
+        next_batch: Token to fetch next set of results with, if
             None then there are no more results.
-        prev_batch (Any|None): Token to fetch previous set of results with, if
+        prev_batch: Token to fetch previous set of results with, if
             None then there are no previous results.
     """
 
-    chunk = attr.ib()
-    next_batch = attr.ib(default=None)
-    prev_batch = attr.ib(default=None)
+    chunk = attr.ib(type=List[JsonDict])
+    next_batch = attr.ib(type=Optional[Any], default=None)
+    prev_batch = attr.ib(type=Optional[Any], default=None)
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         d = {"chunk": self.chunk}
 
         if self.next_batch:
@@ -59,25 +61,25 @@ class RelationPaginationToken:
     boundaries of the chunk as pagination tokens.
 
     Attributes:
-        topological (int): The topological ordering of the boundary event
-        stream (int): The stream ordering of the boundary event.
+        topological: The topological ordering of the boundary event
+        stream: The stream ordering of the boundary event.
     """
 
-    topological = attr.ib()
-    stream = attr.ib()
+    topological = attr.ib(type=int)
+    stream = attr.ib(type=int)
 
     @staticmethod
-    def from_string(string):
+    def from_string(string: str) -> "RelationPaginationToken":
         try:
             t, s = string.split("-")
             return RelationPaginationToken(int(t), int(s))
         except ValueError:
             raise SynapseError(400, "Invalid token")
 
-    def to_string(self):
+    def to_string(self) -> str:
         return "%d-%d" % (self.topological, self.stream)
 
-    def as_tuple(self):
+    def as_tuple(self) -> Tuple[Any, ...]:
         return attr.astuple(self)
 
 
@@ -89,23 +91,23 @@ class AggregationPaginationToken:
     aggregation groups, we can just use them as our pagination token.
 
     Attributes:
-        count (int): The count of relations in the boundar group.
-        stream (int): The MAX stream ordering in the boundary group.
+        count: The count of relations in the boundary group.
+        stream: The MAX stream ordering in the boundary group.
     """
 
-    count = attr.ib()
-    stream = attr.ib()
+    count = attr.ib(type=int)
+    stream = attr.ib(type=int)
 
     @staticmethod
-    def from_string(string):
+    def from_string(string: str) -> "AggregationPaginationToken":
         try:
             c, s = string.split("-")
             return AggregationPaginationToken(int(c), int(s))
         except ValueError:
             raise SynapseError(400, "Invalid token")
 
-    def to_string(self):
+    def to_string(self) -> str:
         return "%d-%d" % (self.count, self.stream)
 
-    def as_tuple(self):
+    def as_tuple(self) -> Tuple[Any, ...]:
         return attr.astuple(self)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 08a69f2f96..31ccbf23dc 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -12,9 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
@@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+    from synapse.storage.databases import Databases
+
 logger = logging.getLogger(__name__)
 
 # Used for generic functions below
@@ -330,10 +343,12 @@ class StateGroupStorage:
     """High level interface to fetching state for event.
     """
 
-    def __init__(self, hs, stores):
+    def __init__(self, hs: "HomeServer", stores: "Databases"):
         self.stores = stores
 
-    async def get_state_group_delta(self, state_group: int):
+    async def get_state_group_delta(
+        self, state_group: int
+    ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
@@ -341,8 +356,8 @@ class StateGroupStorage:
             state_group: The state group used to retrieve state deltas.
 
         Returns:
-            Tuple[Optional[int], Optional[StateMap[str]]]:
-                (prev_group, delta_ids)
+            A tuple of the previous group and a state map of the event IDs which
+            make up the delta between the old and new state groups.
         """
 
         return await self.stores.state.get_state_group_delta(state_group)
@@ -436,7 +451,7 @@ class StateGroupStorage:
 
     async def get_state_for_events(
         self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
 
@@ -472,7 +487,7 @@ class StateGroupStorage:
 
     async def get_state_ids_for_events(
         self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
         of the state events (as opposed to the events themselves)
@@ -500,7 +515,7 @@ class StateGroupStorage:
 
     async def get_state_for_event(
         self, event_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[EventBase]:
         """
         Get the state dict corresponding to a particular event
 
@@ -516,7 +531,7 @@ class StateGroupStorage:
 
     async def get_state_ids_for_event(
         self, event_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 02d71302ea..71ef5a72dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
 import heapq
 import logging
 import threading
-from collections import deque
+from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import attr
-from typing_extensions import Deque
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()  # type: Deque[int]
+
+        # We use this as an ordered set, as we want to efficiently append items,
+        # remove items and get the first item. Since we insert IDs in order, the
+        # insertion ordering will ensure its in the correct ordering.
+        #
+        # The key and values are the same, but we never look at the values.
+        self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]
 
     def get_next(self):
         """
@@ -113,7 +118,7 @@ class StreamIdGenerator:
             self._current += self._step
             next_id = self._current
 
-            self._unfinished_ids.append(next_id)
+            self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
                 yield next_id
             finally:
                 with self._lock:
-                    self._unfinished_ids.remove(next_id)
+                    self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
@@ -140,7 +145,7 @@ class StreamIdGenerator:
             self._current += n * self._step
 
             for next_id in next_ids:
-                self._unfinished_ids.append(next_id)
+                self._unfinished_ids[next_id] = next_id
 
         @contextmanager
         def manager():
@@ -149,20 +154,20 @@ class StreamIdGenerator:
             finally:
                 with self._lock:
                     for next_id in next_ids:
-                        self._unfinished_ids.remove(next_id)
+                        self._unfinished_ids.pop(next_id)
 
         return _AsyncCtxManagerWrapper(manager())
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
 
         Returns:
-            int
+            The maximum stream id.
         """
         with self._lock:
             if self._unfinished_ids:
-                return self._unfinished_ids[0] - self._step
+                return next(iter(self._unfinished_ids)) - self._step
 
             return self._current
 
@@ -186,11 +191,12 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
-        stream_name: A name for the stream.
+        stream_name: A name for the stream, for use in the `stream_positions`
+            table. (Does not need to be the same as the replication stream name)
         instance_name: The name of this instance.
-        table: Database table associated with stream.
-        instance_column: Column that stores the row's writer's instance name
-        id_column: Column that stores the stream ID.
+        tables: List of tables associated with the stream. Tuple of table
+            name, column name that stores the writer's instance name, and
+            column name that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
         writers: A list of known writers to use to populate current positions
@@ -206,9 +212,7 @@ class MultiWriterIdGenerator:
         db: DatabasePool,
         stream_name: str,
         instance_name: str,
-        table: str,
-        instance_column: str,
-        id_column: str,
+        tables: List[Tuple[str, str, str]],
         sequence_name: str,
         writers: List[str],
         positive: bool = True,
@@ -260,15 +264,20 @@ class MultiWriterIdGenerator:
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
         # We check that the table and sequence haven't diverged.
-        self._sequence_gen.check_consistency(
-            db_conn, table=table, id_column=id_column, positive=positive
-        )
+        for table, _, id_column in tables:
+            self._sequence_gen.check_consistency(
+                db_conn,
+                table=table,
+                id_column=id_column,
+                stream_name=stream_name,
+                positive=positive,
+            )
 
         # This goes and fills out the above state from the database.
-        self._load_current_ids(db_conn, table, instance_column, id_column)
+        self._load_current_ids(db_conn, tables)
 
     def _load_current_ids(
-        self, db_conn, table: str, instance_column: str, id_column: str
+        self, db_conn, tables: List[Tuple[str, str, str]],
     ):
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
@@ -306,17 +315,22 @@ class MultiWriterIdGenerator:
             # We add a GREATEST here to ensure that the result is always
             # positive. (This can be a problem for e.g. backfill streams where
             # the server has never backfilled).
-            sql = """
-                SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
-                FROM %(table)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "agg": "MAX" if self._positive else "-MIN",
-            }
-            cur.execute(sql)
-            (stream_id,) = cur.fetchone()
-            self._persisted_upto_position = stream_id
+            max_stream_id = 1
+            for table, _, id_column in tables:
+                sql = """
+                    SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                    FROM %(table)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "agg": "MAX" if self._positive else "-MIN",
+                }
+                cur.execute(sql)
+                (stream_id,) = cur.fetchone()
+
+                max_stream_id = max(max_stream_id, stream_id)
+
+            self._persisted_upto_position = max_stream_id
         else:
             # If we have a min_stream_id then we pull out everything greater
             # than it from the DB so that we can prefill
@@ -329,21 +343,28 @@ class MultiWriterIdGenerator:
             # stream positions table before restart (or the stream position
             # table otherwise got out of date).
 
-            sql = """
-                SELECT %(instance)s, %(id)s FROM %(table)s
-                WHERE ? %(cmp)s %(id)s
-            """ % {
-                "id": id_column,
-                "table": table,
-                "instance": instance_column,
-                "cmp": "<=" if self._positive else ">=",
-            }
-            cur.execute(sql, (min_stream_id * self._return_factor,))
-
             self._persisted_upto_position = min_stream_id
 
+            rows = []
+            for table, instance_column, id_column in tables:
+                sql = """
+                    SELECT %(instance)s, %(id)s FROM %(table)s
+                    WHERE ? %(cmp)s %(id)s
+                """ % {
+                    "id": id_column,
+                    "table": table,
+                    "instance": instance_column,
+                    "cmp": "<=" if self._positive else ">=",
+                }
+                cur.execute(sql, (min_stream_id * self._return_factor,))
+
+                rows.extend(cur)
+
+            # Sort so that we handle rows in order for each instance.
+            rows.sort()
+
             with self._lock:
-                for (instance, stream_id,) in cur:
+                for (instance, stream_id,) in rows:
                     stream_id = self._return_factor * stream_id
                     self._add_persisted_position(stream_id)
 
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 4386b6101e..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -15,9 +15,8 @@
 import abc
 import logging
 import threading
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
-from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
@@ -25,6 +24,9 @@ from synapse.storage.engines import (
 )
 from synapse.storage.types import Connection, Cursor
 
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
+
 logger = logging.getLogger(__name__)
 
 
@@ -43,6 +45,21 @@ and run the following SQL:
 See docs/postgres.md for more information.
 """
 
+_INCONSISTENT_STREAM_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated stream position
+of '%(stream_name)s' in the 'stream_positions' table.
+
+This is likely a programming error and should be reported at
+https://github.com/matrix-org/synapse.
+
+A temporary workaround to fix this error is to shut down Synapse (including
+any and all workers) and run the following SQL:
+
+    DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s';
+
+This will need to be done every time the server is restarted.
+"""
+
 
 class SequenceGenerator(metaclass=abc.ABCMeta):
     """A class which generates a unique sequence of integers"""
@@ -53,19 +70,30 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        """Get the next `n` IDs in the sequence"""
+        ...
+
+    @abc.abstractmethod
     def check_consistency(
         self,
-        db_conn: LoggingDatabaseConnection,
+        db_conn: "LoggingDatabaseConnection",
         table: str,
         id_column: str,
+        stream_name: Optional[str] = None,
         positive: bool = True,
     ):
         """Should be called during start up to test that the current value of
         the sequence is greater than or equal to the maximum ID in the table.
 
-        This is to handle various cases where the sequence value can get out
-        of sync with the table, e.g. if Synapse gets rolled back to a previous
+        This is to handle various cases where the sequence value can get out of
+        sync with the table, e.g. if Synapse gets rolled back to a previous
         version and the rolled forwards again.
+
+        If a stream name is given then this will check that any value in the
+        `stream_positions` table is less than or equal to the current sequence
+        value. If it isn't then it's likely that streams have been crossed
+        somewhere (e.g. two ID generators have the same stream name).
         """
         ...
 
@@ -88,11 +116,15 @@ class PostgresSequenceGenerator(SequenceGenerator):
 
     def check_consistency(
         self,
-        db_conn: LoggingDatabaseConnection,
+        db_conn: "LoggingDatabaseConnection",
         table: str,
         id_column: str,
+        stream_name: Optional[str] = None,
         positive: bool = True,
     ):
+        """See SequenceGenerator.check_consistency for docstring.
+        """
+
         txn = db_conn.cursor(txn_name="sequence.check_consistency")
 
         # First we get the current max ID from the table.
@@ -116,6 +148,18 @@ class PostgresSequenceGenerator(SequenceGenerator):
             "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
         )
         last_value, is_called = txn.fetchone()
+
+        # If we have an associated stream check the stream_positions table.
+        max_in_stream_positions = None
+        if stream_name:
+            txn.execute(
+                "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?",
+                (stream_name,),
+            )
+            row = txn.fetchone()
+            if row:
+                max_in_stream_positions = row[0]
+
         txn.close()
 
         # If `is_called` is False then `last_value` is actually the value that
@@ -136,6 +180,14 @@ class PostgresSequenceGenerator(SequenceGenerator):
                 % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
             )
 
+        # If we have values in the stream positions table then they have to be
+        # less than or equal to `last_value`
+        if max_in_stream_positions and max_in_stream_positions > last_value:
+            raise IncorrectDatabaseSetup(
+                _INCONSISTENT_STREAM_ERROR
+                % {"seq": self._sequence_name, "stream_name": stream_name}
+            )
+
 
 GetFirstCallbackType = Callable[[Cursor], int]
 
@@ -172,8 +224,24 @@ class LocalSequenceGenerator(SequenceGenerator):
             self._current_max_id += 1
             return self._current_max_id
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            first_id = self._current_max_id + 1
+            self._current_max_id += n
+            return [first_id + i for i in range(n)]
+
     def check_consistency(
-        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+        self,
+        db_conn: Connection,
+        table: str,
+        id_column: str,
+        stream_name: Optional[str] = None,
+        positive: bool = True,
     ):
         # There is nothing to do for in memory sequences
         pass
diff --git a/synapse/types.py b/synapse/types.py
index 3ab6bdbe06..eafe729dfe 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -37,6 +37,7 @@ from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationService
@@ -257,8 +258,13 @@ class DomainSpecificString(
 
     @classmethod
     def is_valid(cls: Type[DS], s: str) -> bool:
+        """Parses the input string and attempts to ensure it is valid."""
         try:
-            cls.from_string(s)
+            obj = cls.from_string(s)
+            # Apply additional validation to the domain. This is only done
+            # during  is_valid (and not part of from_string) since it is
+            # possible for invalid data to exist in room-state, etc.
+            parse_and_validate_server_name(obj.domain)
             return True
         except Exception:
             return False
@@ -349,15 +355,17 @@ NON_MXID_CHARACTER_PATTERN = re.compile(
 )
 
 
-def map_username_to_mxid_localpart(username, case_sensitive=False):
+def map_username_to_mxid_localpart(
+    username: Union[str, bytes], case_sensitive: bool = False
+) -> str:
     """Map a username onto a string suitable for a MXID
 
     This follows the algorithm laid out at
     https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
 
     Args:
-        username (unicode|bytes): username to be mapped
-        case_sensitive (bool): true if TEST and test should be mapped
+        username: username to be mapped
+        case_sensitive: true if TEST and test should be mapped
             onto different mxids
 
     Returns:
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3f0..9a873c8e8e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,10 +15,12 @@
 # limitations under the License.
 
 import collections
+import inspect
 import logging
 from contextlib import contextmanager
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Dict,
     Hashable,
@@ -542,11 +544,11 @@ class DoneAwaitable:
         raise StopIteration(self.value)
 
 
-def maybe_awaitable(value):
+def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
     """Convert a value to an awaitable if not already an awaitable.
     """
-
-    if hasattr(value, "__await__"):
+    if inspect.isawaitable(value):
+        assert isinstance(value, Awaitable)
         return value
 
     return DoneAwaitable(value)
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 601305487c..1adc92eb90 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -105,7 +105,7 @@ class DeferredCache(Generic[KT, VT]):
             keylen=keylen,
             cache_name=name,
             cache_type=cache_type,
-            size_callback=(lambda d: len(d)) if iterable else None,
+            size_callback=(lambda d: len(d) or 1) if iterable else None,
             metrics_collection_callback=metrics_cb,
             apply_cache_factor_from_config=apply_cache_factor_from_config,
         )  # type: LruCache[KT, VT]
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e95393c..a6ee9edaec 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import inspect
 import logging
 
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.async_helpers import maybe_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -105,10 +105,7 @@ class Signal:
 
         async def do(observer):
             try:
-                result = observer(*args, **kwargs)
-                if inspect.isawaitable(result):
-                    result = await result
-                return result
+                return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
                 logger.warning(
                     "%s signal observer %s failed: %r", self.name, observer, e,
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 06faeebe7f..8d2411513f 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -13,8 +13,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import heapq
 from itertools import islice
-from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+from typing import (
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    Mapping,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+)
+
+from synapse.types import Collection
 
 T = TypeVar("T")
 
@@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
     If the input is empty, no chunks are returned.
     """
     return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+
+
+def sorted_topologically(
+    nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+) -> Generator[T, None, None]:
+    """Given a set of nodes and a graph, yield the nodes in toplogical order.
+
+    For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
+    """
+
+    # This is implemented by Kahn's algorithm.
+
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph = {}  # type: Dict[T, Set[T]]
+
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+
+        for edge in set(edges):
+            if edge in degree_map:
+                degree_map[node] += 1
+
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+    heapq.heapify(zero_degree)
+
+    while zero_degree:
+        node = heapq.heappop(zero_degree)
+        yield node
+
+        for edge in reverse_graph.get(node, []):
+            if edge in degree_map:
+                degree_map[edge] -= 1
+                if degree_map[edge] == 0:
+                    heapq.heappush(zero_degree, edge)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ffdea0de8d..f4de6b9f54 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -108,7 +108,16 @@ class Measure:
     def __init__(self, clock, name):
         self.clock = clock
         self.name = name
-        parent_context = current_context()
+        curr_context = current_context()
+        if not curr_context:
+            logger.warning(
+                "Starting metrics collection %r from sentinel context: metrics will be lost",
+                name,
+            )
+            parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
         self._logging_context = LoggingContext(
             "Measure[%s]" % (self.name,), parent_context
         )
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 94b59afb38..09b094ded7 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -15,28 +15,57 @@
 
 import importlib
 import importlib.util
+import itertools
+from typing import Any, Iterable, Tuple, Type
+
+import jsonschema
 
 from synapse.config._base import ConfigError
+from synapse.config._util import json_error_to_config_error
 
 
-def load_module(provider):
+def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
     """ Loads a synapse module with its config
-    Take a dict with keys 'module' (the module name) and 'config'
-    (the config dict).
+
+    Args:
+        provider: a dict with keys 'module' (the module name) and 'config'
+           (the config dict).
+        config_path: the path within the config file. This will be used as a basis
+           for any error message.
 
     Returns
         Tuple of (provider class, parsed config object)
     """
+
+    modulename = provider.get("module")
+    if not isinstance(modulename, str):
+        raise ConfigError(
+            "expected a string", path=itertools.chain(config_path, ("module",))
+        )
+
     # We need to import the module, and then pick the class out of
     # that, so we split based on the last dot.
-    module, clz = provider["module"].rsplit(".", 1)
+    module, clz = modulename.rsplit(".", 1)
     module = importlib.import_module(module)
     provider_class = getattr(module, clz)
 
+    # Load the module config. If None, pass an empty dictionary instead
+    module_config = provider.get("config") or {}
     try:
-        provider_config = provider_class.parse_config(provider.get("config"))
+        provider_config = provider_class.parse_config(module_config)
+    except jsonschema.ValidationError as e:
+        raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
+    except ConfigError as e:
+        raise _wrap_config_error(
+            "Failed to parse config for module %r" % (modulename,),
+            prefix=itertools.chain(config_path, ("config",)),
+            e=e,
+        )
     except Exception as e:
-        raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
+        raise ConfigError(
+            "Failed to parse config for module %r" % (modulename,),
+            path=itertools.chain(config_path, ("config",)),
+        ) from e
 
     return provider_class, provider_config
 
@@ -56,3 +85,27 @@ def load_python_module(location: str):
     mod = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(mod)  # type: ignore
     return mod
+
+
+def _wrap_config_error(
+    msg: str, prefix: Iterable[str], e: ConfigError
+) -> "ConfigError":
+    """Wrap a relative ConfigError with a new path
+
+    This is useful when we have a ConfigError with a relative path due to a problem
+    parsing part of the config, and we now need to set it in context.
+    """
+    path = prefix
+    if e.path:
+        path = itertools.chain(prefix, e.path)
+
+    e1 = ConfigError(msg, path)
+
+    # ideally we would set the 'cause' of the new exception to the original exception;
+    # however now that we have merged the path into our own, the stringification of
+    # e will be incorrect, so instead we create a new exception with just the "msg"
+    # part.
+
+    e1.__cause__ = Exception(e.msg)
+    e1.__cause__.__cause__ = e.__cause__
+    return e1
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 61d96a6c28..f8038bf861 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -18,6 +18,7 @@ import random
 import re
 import string
 from collections.abc import Iterable
+from typing import Optional, Tuple
 
 from synapse.api.errors import Codes, SynapseError
 
@@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
 client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
 
+# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
+# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
+# says "there is no grammar for media ids"
+#
+# The server_name part of this is purposely lax: use parse_and_validate_mxc for
+# additional validation.
+#
+MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
+
 # random_string and random_string_with_symbols are used for a range of things,
 # some cryptographically important, some less so. We use SystemRandom to make sure
 # we get cryptographically-secure randoms.
@@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret):
         )
 
 
+def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+    """Split a server name into host/port parts.
+
+    Args:
+        server_name: server name to parse
+
+    Returns:
+        host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    try:
+        if server_name[-1] == "]":
+            # ipv6 literal, hopefully
+            return server_name, None
+
+        domain_port = server_name.rsplit(":", 1)
+        domain = domain_port[0]
+        port = int(domain_port[1]) if domain_port[1:] else None
+        return domain, port
+    except Exception:
+        raise ValueError("Invalid server name '%s'" % server_name)
+
+
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
+
+
+def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+    """Split a server name into host/port parts and do some basic validation.
+
+    Args:
+        server_name: server name to parse
+
+    Returns:
+        host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    host, port = parse_server_name(server_name)
+
+    # these tests don't need to be bulletproof as we'll find out soon enough
+    # if somebody is giving us invalid data. What we *do* need is to be sure
+    # that nobody is sneaking IP literals in that look like hostnames, etc.
+
+    # look for ipv6 literals
+    if host[0] == "[":
+        if host[-1] != "]":
+            raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
+        return host, port
+
+    # otherwise it should only be alphanumerics.
+    if not VALID_HOST_REGEX.match(host):
+        raise ValueError(
+            "Server name '%s' contains invalid characters" % (server_name,)
+        )
+
+    return host, port
+
+
+def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
+    """Parse the given string as an MXC URI
+
+    Checks that the "server name" part is a valid server name
+
+    Args:
+        mxc: the (alleged) MXC URI to be checked
+    Returns:
+        hostname, port, media id
+    Raises:
+        ValueError if the URI cannot be parsed
+    """
+    m = MXC_REGEX.match(mxc)
+    if not m:
+        raise ValueError("mxc URI %r did not match expected format" % (mxc,))
+    server_name = m.group(1)
+    media_id = m.group(2)
+    host, port = parse_and_validate_server_name(server_name)
+    return host, port, media_id
+
+
 def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
     """If iterable has maxitems or fewer, return the stringification of a list
     containing those items.
@@ -75,3 +167,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
     if len(items) <= maxitems:
         return str(items)
     return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
+
+
+def strtobool(val: str) -> bool:
+    """Convert a string representation of truth to True or False
+
+    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+    are 'n', 'no', 'f', 'false', 'off', and '0'.  Raises ValueError if
+    'val' is anything else.
+
+    This is lifted from distutils.util.strtobool, with the exception that it actually
+    returns a bool, rather than an int.
+    """
+    val = val.lower()
+    if val in ("y", "yes", "t", "true", "on", "1"):
+        return True
+    elif val in ("n", "no", "f", "false", "off", "0"):
+        return False
+    else:
+        raise ValueError("invalid truth value %r" % (val,))
diff --git a/synapse/util/templates.py b/synapse/util/templates.py
new file mode 100644
index 0000000000..392dae4a40
--- /dev/null
+++ b/synapse/util/templates.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for dealing with jinja2 templates"""
+
+import time
+import urllib.parse
+from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
+
+import jinja2
+
+if TYPE_CHECKING:
+    from synapse.config.homeserver import HomeServerConfig
+
+
+def build_jinja_env(
+    template_search_directories: Iterable[str],
+    config: "HomeServerConfig",
+    autoescape: Union[bool, Callable[[str], bool], None] = None,
+) -> jinja2.Environment:
+    """Set up a Jinja2 environment to load templates from the given search path
+
+    The returned environment defines the following filters:
+        - format_ts: formats timestamps as strings in the server's local timezone
+             (XXX: why is that useful??)
+        - mxc_to_http: converts mxc: uris to http URIs. Args are:
+             (uri, width, height, resize_method="crop")
+
+    and the following global variables:
+        - server_name: matrix server name
+
+    Args:
+        template_search_directories: directories to search for templates
+
+        config: homeserver config, for things like `server_name` and `public_baseurl`
+
+        autoescape: whether template variables should be autoescaped. bool, or
+           a function mapping from template name to bool. Defaults to escaping templates
+           whose names end in .html, .xml or .htm.
+
+    Returns:
+        jinja environment
+    """
+
+    if autoescape is None:
+        autoescape = jinja2.select_autoescape()
+
+    loader = jinja2.FileSystemLoader(template_search_directories)
+    env = jinja2.Environment(loader=loader, autoescape=autoescape)
+
+    # Update the environment with our custom filters
+    env.filters.update(
+        {
+            "format_ts": _format_ts_filter,
+            "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl),
+        }
+    )
+
+    # common variables for all templates
+    env.globals.update({"server_name": config.server_name})
+
+    return env
+
+
+def _create_mxc_to_http_filter(
+    public_baseurl: Optional[str],
+) -> Callable[[str, int, int, str], str]:
+    """Create and return a jinja2 filter that converts MXC urls to HTTP
+
+    Args:
+        public_baseurl: The public, accessible base URL of the homeserver
+    """
+
+    def mxc_to_http_filter(
+        value: str, width: int, height: int, resize_method: str = "crop"
+    ) -> str:
+        if not public_baseurl:
+            raise RuntimeError(
+                "public_baseurl must be set in the homeserver config to convert MXC URLs to HTTP URLs."
+            )
+
+        if value[0:6] != "mxc://":
+            return ""
+
+        server_and_media_id = value[6:]
+        fragment = None
+        if "#" in server_and_media_id:
+            server_and_media_id, fragment = server_and_media_id.split("#", 1)
+            fragment = "#" + fragment
+
+        params = {"width": width, "height": height, "method": resize_method}
+        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+            public_baseurl,
+            server_and_media_id,
+            urllib.parse.urlencode(params),
+            fragment or "",
+        )
+
+    return mxc_to_http_filter
+
+
+def _format_ts_filter(value: int, format: str):
+    return time.strftime(format, time.localtime(value / 1000))
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 527365498e..ec50e7e977 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -12,11 +12,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import operator
 
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import (
+    AccountDataTypes,
+    EventTypes,
+    HistoryVisibility,
+    Membership,
+)
 from synapse.events.utils import prune_event
 from synapse.storage import Storage
 from synapse.storage.state import StateFilter
@@ -25,7 +29,12 @@ from synapse.types import get_domain_from_id
 logger = logging.getLogger(__name__)
 
 
-VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined")
+VISIBILITY_PRIORITY = (
+    HistoryVisibility.WORLD_READABLE,
+    HistoryVisibility.SHARED,
+    HistoryVisibility.INVITED,
+    HistoryVisibility.JOINED,
+)
 
 
 MEMBERSHIP_PRIORITY = (
@@ -116,7 +125,7 @@ async def filter_events_for_client(
         # see events in the room at that point in the DAG, and that shouldn't be decided
         # on those checks.
         if filter_send_to_client:
-            if event.type == "org.matrix.dummy_event":
+            if event.type == EventTypes.Dummy:
                 return None
 
             if not event.is_state() and event.sender in ignore_list:
@@ -150,12 +159,14 @@ async def filter_events_for_client(
         # get the room_visibility at the time of the event.
         visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
         if visibility_event:
-            visibility = visibility_event.content.get("history_visibility", "shared")
+            visibility = visibility_event.content.get(
+                "history_visibility", HistoryVisibility.SHARED
+            )
         else:
-            visibility = "shared"
+            visibility = HistoryVisibility.SHARED
 
         if visibility not in VISIBILITY_PRIORITY:
-            visibility = "shared"
+            visibility = HistoryVisibility.SHARED
 
         # Always allow history visibility events on boundaries. This is done
         # by setting the effective visibility to the least restrictive
@@ -165,7 +176,7 @@ async def filter_events_for_client(
             prev_visibility = prev_content.get("history_visibility", None)
 
             if prev_visibility not in VISIBILITY_PRIORITY:
-                prev_visibility = "shared"
+                prev_visibility = HistoryVisibility.SHARED
 
             new_priority = VISIBILITY_PRIORITY.index(visibility)
             old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
@@ -210,17 +221,17 @@ async def filter_events_for_client(
 
         # otherwise, it depends on the room visibility.
 
-        if visibility == "joined":
+        if visibility == HistoryVisibility.JOINED:
             # we weren't a member at the time of the event, so we can't
             # see this event.
             return None
 
-        elif visibility == "invited":
+        elif visibility == HistoryVisibility.INVITED:
             # user can also see the event if they were *invited* at the time
             # of the event.
             return event if membership == Membership.INVITE else None
 
-        elif visibility == "shared" and is_peeking:
+        elif visibility == HistoryVisibility.SHARED and is_peeking:
             # if the visibility is shared, users cannot see the event unless
             # they have *subequently* joined the room (or were members at the
             # time, of course)
@@ -284,8 +295,10 @@ async def filter_events_for_server(
     def check_event_is_visible(event, state):
         history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
         if history:
-            visibility = history.content.get("history_visibility", "shared")
-            if visibility in ["invited", "joined"]:
+            visibility = history.content.get(
+                "history_visibility", HistoryVisibility.SHARED
+            )
+            if visibility in [HistoryVisibility.INVITED, HistoryVisibility.JOINED]:
                 # We now loop through all state events looking for
                 # membership states for the requesting server to determine
                 # if the server is either in the room or has been invited
@@ -305,7 +318,7 @@ async def filter_events_for_server(
                     if memtype == Membership.JOIN:
                         return True
                     elif memtype == Membership.INVITE:
-                        if visibility == "invited":
+                        if visibility == HistoryVisibility.INVITED:
                             return True
                 else:
                     # server has no users in the room: redact
@@ -336,7 +349,8 @@ async def filter_events_for_server(
     else:
         event_map = await storage.main.get_events(visibility_ids)
         all_open = all(
-            e.content.get("history_visibility") in (None, "shared", "world_readable")
+            e.content.get("history_visibility")
+            in (None, HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
             for e in event_map.values()
         )
 
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index c98ae75974..279c94a03d 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -16,8 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from mock import Mock
-
 import jsonschema
 
 from twisted.internet import defer
@@ -28,7 +26,7 @@ from synapse.api.filtering import Filter
 from synapse.events import make_event_from_dict
 
 from tests import unittest
-from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver
+from tests.utils import setup_test_homeserver
 
 user_localpart = "test_user"
 
@@ -42,19 +40,9 @@ def MockEvent(**kwargs):
 
 
 class FilteringTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
     def setUp(self):
-        self.mock_federation_resource = MockHttpResource()
-
-        self.mock_http_client = Mock(spec=[])
-        self.mock_http_client.put_json = DeferredMockCallable()
-
-        hs = yield setup_test_homeserver(
-            self.addCleanup, http_client=self.mock_http_client, keyring=Mock(),
-        )
-
+        hs = setup_test_homeserver(self.addCleanup)
         self.filtering = hs.get_filtering()
-
         self.datastore = hs.get_datastore()
 
     def test_errors_on_invalid_filters(self):
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 40abe9d72d..e0ca288829 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -23,7 +23,7 @@ class FrontendProxyTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         hs = self.setup_test_homeserver(
-            http_client=None, homeserver_to_use=GenericWorkerServer
+            federation_http_client=None, homeserver_to_use=GenericWorkerServer
         )
 
         return hs
@@ -57,7 +57,7 @@ class FrontendProxyTests(HomeserverTestCase):
         self.assertEqual(len(self.reactor.tcpServers), 1)
         site = self.reactor.tcpServers[0][1]
 
-        _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+        channel = make_request(self.reactor, site, "PUT", "presence/a/status")
 
         # 400 + unrecognised, because nothing is registered
         self.assertEqual(channel.code, 400)
@@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase):
         self.assertEqual(len(self.reactor.tcpServers), 1)
         site = self.reactor.tcpServers[0][1]
 
-        _, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+        channel = make_request(self.reactor, site, "PUT", "presence/a/status")
 
         # 401, because the stub servlet still checks authentication
         self.assertEqual(channel.code, 401)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index ea3be95cf1..467033e201 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase
 class FederationReaderOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            http_client=None, homeserver_to_use=GenericWorkerServer
+            federation_http_client=None, homeserver_to_use=GenericWorkerServer
         )
         return hs
 
@@ -73,7 +73,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
                 return
             raise
 
-        _, channel = make_request(
+        channel = make_request(
             self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
         )
 
@@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
 class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            http_client=None, homeserver_to_use=SynapseHomeServer
+            federation_http_client=None, homeserver_to_use=SynapseHomeServer
         )
         return hs
 
@@ -121,7 +121,7 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
                 return
             raise
 
-        _, channel = make_request(
+        channel = make_request(
             self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
         )
 
diff --git a/tests/config/test_util.py b/tests/config/test_util.py
new file mode 100644
index 0000000000..10363e3765
--- /dev/null
+++ b/tests/config/test_util.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config import ConfigError
+from synapse.config._util import validate_config
+
+from tests.unittest import TestCase
+
+
+class ValidateConfigTestCase(TestCase):
+    """Test cases for synapse.config._util.validate_config"""
+
+    def test_bad_object_in_array(self):
+        """malformed objects within an array should be validated correctly"""
+
+        # consider a structure:
+        #
+        # array_of_objs:
+        #   - r: 1
+        #     foo: 2
+        #
+        #   - r: 2
+        #     bar: 3
+        #
+        # ... where each entry must contain an "r": check that the path
+        # to the required item is correclty reported.
+
+        schema = {
+            "type": "object",
+            "properties": {
+                "array_of_objs": {
+                    "type": "array",
+                    "items": {"type": "object", "required": ["r"]},
+                },
+            },
+        }
+
+        with self.assertRaises(ConfigError) as c:
+            validate_config(schema, {"array_of_objs": [{}]}, ("base",))
+
+        self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 697916a019..1d65ea2f9c 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         return val
 
     def test_verify_json_objects_for_server_awaits_previous_requests(self):
-        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock()
         kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
 
@@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """Tests that we correctly handle key requests for keys we've stored
         with a null `ts_valid_until_ms`
         """
-        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
 
         kr = keyring.Keyring(
@@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 }
             }
 
-        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(side_effect=get_keys)
         kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
 
@@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 }
             }
 
-        mock_fetcher1 = keyring.KeyFetcher()
+        mock_fetcher1 = Mock()
         mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
-        mock_fetcher2 = keyring.KeyFetcher()
+        mock_fetcher2 = Mock()
         mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
         kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
 
@@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock()
-        hs = self.setup_test_homeserver(http_client=self.http_client)
+        hs = self.setup_test_homeserver(federation_http_client=self.http_client)
         return hs
 
     def test_get_keys_from_server(self):
@@ -395,7 +395,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
             }
         ]
 
-        return self.setup_test_homeserver(http_client=self.http_client, config=config)
+        return self.setup_test_homeserver(
+            federation_http_client=self.http_client, config=config
+        )
 
     def build_perspectives_response(
         self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index c1274c14af..8ba36c6074 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -34,11 +34,17 @@ def MockEvent(**kwargs):
 
 
 class PruneEventTestCase(unittest.TestCase):
-    """ Asserts that a new event constructed with `evdict` will look like
-    `matchdict` when it is redacted. """
-
     def run_test(self, evdict, matchdict, **kwargs):
-        self.assertEquals(
+        """
+        Asserts that a new event constructed with `evdict` will look like
+        `matchdict` when it is redacted.
+
+        Args:
+             evdict: The dictionary to build the event from.
+             matchdict: The expected resulting dictionary.
+             kwargs: Additional keyword arguments used to create the event.
+        """
+        self.assertEqual(
             prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
         )
 
@@ -55,54 +61,80 @@ class PruneEventTestCase(unittest.TestCase):
         )
 
     def test_basic_keys(self):
+        """Ensure that the keys that should be untouched are kept."""
+        # Note that some of the values below don't really make sense, but the
+        # pruning of events doesn't worry about the values of any fields (with
+        # the exception of the content field).
         self.run_test(
             {
+                "event_id": "$3:domain",
                 "type": "A",
                 "room_id": "!1:domain",
                 "sender": "@2:domain",
-                "event_id": "$3:domain",
+                "state_key": "B",
+                "content": {"other_key": "foo"},
+                "hashes": "hashes",
+                "signatures": {"domain": {"algo:1": "sigs"}},
+                "depth": 4,
+                "prev_events": "prev_events",
+                "prev_state": "prev_state",
+                "auth_events": "auth_events",
                 "origin": "domain",
+                "origin_server_ts": 1234,
+                "membership": "join",
+                # Also include a key that should be removed.
+                "other_key": "foo",
             },
             {
+                "event_id": "$3:domain",
                 "type": "A",
                 "room_id": "!1:domain",
                 "sender": "@2:domain",
-                "event_id": "$3:domain",
+                "state_key": "B",
+                "hashes": "hashes",
+                "depth": 4,
+                "prev_events": "prev_events",
+                "prev_state": "prev_state",
+                "auth_events": "auth_events",
                 "origin": "domain",
+                "origin_server_ts": 1234,
+                "membership": "join",
                 "content": {},
-                "signatures": {},
+                "signatures": {"domain": {"algo:1": "sigs"}},
                 "unsigned": {},
             },
         )
 
-    def test_unsigned_age_ts(self):
+        # As of MSC2176 we now redact the membership and prev_states keys.
         self.run_test(
-            {"type": "B", "event_id": "$test:domain", "unsigned": {"age_ts": 20}},
-            {
-                "type": "B",
-                "event_id": "$test:domain",
-                "content": {},
-                "signatures": {},
-                "unsigned": {"age_ts": 20},
-            },
+            {"type": "A", "prev_state": "prev_state", "membership": "join"},
+            {"type": "A", "content": {}, "signatures": {}, "unsigned": {}},
+            room_version=RoomVersions.MSC2176,
         )
 
+    def test_unsigned(self):
+        """Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
         self.run_test(
             {
                 "type": "B",
                 "event_id": "$test:domain",
-                "unsigned": {"other_key": "here"},
+                "unsigned": {
+                    "age_ts": 20,
+                    "replaces_state": "$test2:domain",
+                    "other_key": "foo",
+                },
             },
             {
                 "type": "B",
                 "event_id": "$test:domain",
                 "content": {},
                 "signatures": {},
-                "unsigned": {},
+                "unsigned": {"age_ts": 20, "replaces_state": "$test2:domain"},
             },
         )
 
     def test_content(self):
+        """The content dictionary should be stripped in most cases."""
         self.run_test(
             {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
             {
@@ -114,11 +146,35 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
+        # Some events keep a single content key/value.
+        EVENT_KEEP_CONTENT_KEYS = [
+            ("member", "membership", "join"),
+            ("join_rules", "join_rule", "invite"),
+            ("history_visibility", "history_visibility", "shared"),
+        ]
+        for event_type, key, value in EVENT_KEEP_CONTENT_KEYS:
+            self.run_test(
+                {
+                    "type": "m.room." + event_type,
+                    "event_id": "$test:domain",
+                    "content": {key: value, "other_key": "foo"},
+                },
+                {
+                    "type": "m.room." + event_type,
+                    "event_id": "$test:domain",
+                    "content": {key: value},
+                    "signatures": {},
+                    "unsigned": {},
+                },
+            )
+
+    def test_create(self):
+        """Create events are partially redacted until MSC2176."""
         self.run_test(
             {
                 "type": "m.room.create",
                 "event_id": "$test:domain",
-                "content": {"creator": "@2:domain", "other_field": "here"},
+                "content": {"creator": "@2:domain", "other_key": "foo"},
             },
             {
                 "type": "m.room.create",
@@ -129,6 +185,68 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
+        # After MSC2176, create events get nothing redacted.
+        self.run_test(
+            {"type": "m.room.create", "content": {"not_a_real_key": True}},
+            {
+                "type": "m.room.create",
+                "content": {"not_a_real_key": True},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
+    def test_power_levels(self):
+        """Power level events keep a variety of content keys."""
+        self.run_test(
+            {
+                "type": "m.room.power_levels",
+                "event_id": "$test:domain",
+                "content": {
+                    "ban": 1,
+                    "events": {"m.room.name": 100},
+                    "events_default": 2,
+                    "invite": 3,
+                    "kick": 4,
+                    "redact": 5,
+                    "state_default": 6,
+                    "users": {"@admin:domain": 100},
+                    "users_default": 7,
+                    "other_key": 8,
+                },
+            },
+            {
+                "type": "m.room.power_levels",
+                "event_id": "$test:domain",
+                "content": {
+                    "ban": 1,
+                    "events": {"m.room.name": 100},
+                    "events_default": 2,
+                    # Note that invite is not here.
+                    "kick": 4,
+                    "redact": 5,
+                    "state_default": 6,
+                    "users": {"@admin:domain": 100},
+                    "users_default": 7,
+                },
+                "signatures": {},
+                "unsigned": {},
+            },
+        )
+
+        # After MSC2176, power levels events keep the invite key.
+        self.run_test(
+            {"type": "m.room.power_levels", "content": {"invite": 75}},
+            {
+                "type": "m.room.power_levels",
+                "content": {"invite": 75},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
     def test_alias_event(self):
         """Alias events have special behavior up through room version 6."""
         self.run_test(
@@ -146,8 +264,7 @@ class PruneEventTestCase(unittest.TestCase):
             },
         )
 
-    def test_msc2432_alias_event(self):
-        """After MSC2432, alias events have no special behavior."""
+        # After MSC2432, alias events have no special behavior.
         self.run_test(
             {"type": "m.room.aliases", "content": {"aliases": ["test"]}},
             {
@@ -159,6 +276,32 @@ class PruneEventTestCase(unittest.TestCase):
             room_version=RoomVersions.V6,
         )
 
+    def test_redacts(self):
+        """Redaction events have no special behaviour until MSC2174/MSC2176."""
+
+        self.run_test(
+            {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+            {
+                "type": "m.room.redaction",
+                "content": {},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.V6,
+        )
+
+        # After MSC2174, redaction events keep the redacts content key.
+        self.run_test(
+            {"type": "m.room.redaction", "content": {"redacts": "$test2:domain"}},
+            {
+                "type": "m.room.redaction",
+                "content": {"redacts": "$test2:domain"},
+                "signatures": {},
+                "unsigned": {},
+            },
+            room_version=RoomVersions.MSC2176,
+        )
+
 
 class SerializeEventTestCase(unittest.TestCase):
     def serialize(self, ev, fields):
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0187f56e21..9ccd2d76b8 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -48,7 +48,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         )
 
         # Get the room complexity
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
         )
         self.assertEquals(200, channel.code)
@@ -60,7 +60,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
 
         # Get the room complexity again -- make sure it's our artificial value
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
         )
         self.assertEquals(200, channel.code)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 3009fbb6c4..cfeccc0577 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -46,7 +46,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
 
         "/get_missing_events/(?P<room_id>[^/]*)/?"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
             query_content,
@@ -95,7 +95,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
         room_1 = self.helper.create_room_as(u1, tok=u1_token)
         self.inject_room_member(room_1, "@user:other.example.com", "join")
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
         )
         self.assertEquals(200, channel.code, channel.result)
@@ -127,7 +127,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
 
         room_1 = self.helper.create_room_as(u1, tok=u1_token)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
         )
         self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/federation/transport/__init__.py b/tests/federation/transport/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/federation/transport/__init__.py
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index f9e3c7a51f..85500e169c 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,38 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
 from tests import unittest
 from tests.unittest import override_config
 
 
-class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
-        class Authenticator:
-            def authenticate_request(self, request, content):
-                return defer.succeed("otherserver.nottld")
-
-        ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig())
-        server.register_servlets(
-            homeserver, self.resource, Authenticator(), ratelimiter
-        )
-
+class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
     @override_config({"allow_public_rooms_over_federation": False})
     def test_blocked_public_room_list_over_federation(self):
-        request, channel = self.make_request(
-            "GET", "/_matrix/federation/v1/publicRooms"
+        """Test that unauthenticated requests to the public rooms directory 403 when
+        allow_public_rooms_over_federation is False.
+        """
+        channel = self.make_request(
+            "GET",
+            "/_matrix/federation/v1/publicRooms",
+            federation_auth_origin=b"example.com",
         )
         self.assertEquals(403, channel.code)
 
     @override_config({"allow_public_rooms_over_federation": True})
     def test_open_public_room_list_over_federation(self):
-        request, channel = self.make_request(
-            "GET", "/_matrix/federation/v1/publicRooms"
+        """Test that unauthenticated requests to the public rooms directory 200 when
+        allow_public_rooms_over_federation is True.
+        """
+        channel = self.make_request(
+            "GET",
+            "/_matrix/federation/v1/publicRooms",
+            federation_auth_origin=b"example.com",
         )
         self.assertEquals(200, channel.code)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
new file mode 100644
index 0000000000..7baf224f7e
--- /dev/null
+++ b/tests/handlers/test_cas.py
@@ -0,0 +1,121 @@
+#  Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from mock import Mock
+
+from synapse.handlers.cas_handler import CasResponse
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+SERVER_URL = "https://issuer/"
+
+
+class CasHandlerTestCase(HomeserverTestCase):
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+        cas_config = {
+            "enabled": True,
+            "server_url": SERVER_URL,
+            "service_url": BASE_URL,
+        }
+        config["cas_config"] = cas_config
+
+        return config
+
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+
+        self.handler = hs.get_cas_handler()
+
+        # Reduce the number of attempts when generating MXIDs.
+        sso_handler = hs.get_sso_handler()
+        sso_handler._MAP_USERNAME_RETRIES = 3
+
+        return hs
+
+    def test_map_cas_user_to_user(self):
+        """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
+    def test_map_cas_user_to_existing_user(self):
+        """Existing users can log in with CAS account."""
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.register_user(user_id="@test_user:test", password_hash=None)
+        )
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # Map a user via SSO.
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=False
+        )
+
+        # Subsequent calls should map to the same mxid.
+        auth_handler.complete_sso_login.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=False
+        )
+
+    def test_map_cas_user_to_invalid_localpart(self):
+        """CAS automaps invalid characters to base-64 encoding."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        cas_response = CasResponse("föö", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+        )
+
+
+def _mock_request():
+    """Returns a mock which will stand in as a SynapseRequest"""
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 875aaec2c6..5dfeccfeb6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -27,7 +27,7 @@ user2 = "@theresa:bbb"
 
 class DeviceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.store = hs.get_datastore()
         return hs
@@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ee6ef5e6fa..a39f898608 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,8 +42,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.mock_registry.register_query_handler = register_query_handler
 
         hs = self.setup_test_homeserver(
-            http_client=None,
-            resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_registry=self.mock_registry,
         )
@@ -407,7 +405,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
     def test_denied(self):
         room_id = self.helper.create_room_as(self.user_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             b"directory/room/%23test%3Atest",
             ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -417,7 +415,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
     def test_allowed(self):
         room_id = self.helper.create_room_as(self.user_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             b"directory/room/%23unofficial_test%3Atest",
             ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -433,7 +431,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         room_id = self.helper.create_room_as(self.user_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
         self.assertEquals(200, channel.code, channel.result)
@@ -448,7 +446,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         self.directory_handler.enable_room_list_search = True
 
         # Room list is enabled so we should get some results
-        request, channel = self.make_request("GET", b"publicRooms")
+        channel = self.make_request("GET", b"publicRooms")
         self.assertEquals(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["chunk"]) > 0)
 
@@ -456,13 +454,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         self.directory_handler.enable_room_list_search = False
 
         # Room list disabled so we should get no results
-        request, channel = self.make_request("GET", b"publicRooms")
+        channel = self.make_request("GET", b"publicRooms")
         self.assertEquals(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["chunk"]) == 0)
 
         # Room list disabled so we shouldn't be allowed to publish rooms
         room_id = self.helper.create_room_as(self.user_id)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
         self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bf866dacf3..983e368592 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
 from unittest import TestCase
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
 from synapse.federation.federation_base import event_from_pdu_json
@@ -37,7 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver(http_client=None)
+        hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastore()
         return hs
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             room_version,
         )
 
-        with LoggingContext(request="send_rejected"):
+        with LoggingContext("send_rejected"):
             d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
         self.get_success(d)
 
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             room_version,
         )
 
-        with LoggingContext(request="send_rejected"):
+        with LoggingContext("send_rejected"):
             d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
         self.get_success(d)
 
@@ -191,6 +191,50 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invite_by_user_ratelimit(self):
+        """Tests that invites from federation to a particular user are
+        actually rate-limited.
+        """
+        other_server = "otherserver"
+        other_user = "@otheruser:" + other_server
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+
+        def create_invite():
+            room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+            room_version = self.get_success(self.store.get_room_version(room_id))
+            return event_from_pdu_json(
+                {
+                    "type": EventTypes.Member,
+                    "content": {"membership": "invite"},
+                    "room_id": room_id,
+                    "sender": other_user,
+                    "state_key": "@user:test",
+                    "depth": 32,
+                    "prev_events": [],
+                    "auth_events": [],
+                    "origin_server_ts": self.clock.time_msec(),
+                },
+                room_version,
+            )
+
+        for i in range(3):
+            event = create_invite()
+            self.get_success(
+                self.handler.on_invite_request(other_server, event, event.room_version,)
+            )
+
+        event = create_invite()
+        self.get_failure(
+            self.handler.on_invite_request(other_server, event, event.room_version,),
+            exc=LimitExceededError,
+        )
+
     def _build_and_send_join_event(self, other_server, other_user, room_id):
         join_event = self.get_success(
             self.handler.on_make_join_request(other_server, room_id, other_user)
@@ -198,7 +242,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         # the auth code requires that a signature exists, but doesn't check that
         # signature... go figure.
         join_event.signatures[other_server] = {"x": "y"}
-        with LoggingContext(request="send_join"):
+        with LoggingContext("send_join"):
             d = run_in_background(
                 self.handler.on_send_join_request, other_server, join_event
             )
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index af42775815..f955dfa490 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -206,7 +206,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
 
         # Redaction of event should fail.
         path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", path, content={}, access_token=self.access_token
         )
         self.assertEqual(int(channel.result["code"]), 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..ad20400b1d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,32 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+from typing import Optional
 from urllib.parse import parse_qs, urlparse
 
-from mock import Mock, patch
+from mock import ANY, Mock, patch
 
-import attr
 import pymacaroons
 
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
-from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
 from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
 from synapse.types import UserID
 
+from tests.test_utils import FakeResponse, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
+try:
+    import authlib  # noqa: F401
 
-@attr.s
-class FakeResponse:
-    code = attr.ib()
-    body = attr.ib()
-    phrase = attr.ib()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
+    HAS_OIDC = True
+except ImportError:
+    HAS_OIDC = False
 
 
 # These are a few constants that are used as config parameters in the tests.
@@ -46,7 +40,7 @@ ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
 CLIENT_SECRET = "test-client-secret"
 BASE_URL = "https://synapse/"
-CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
 SCOPES = ["openid"]
 
 AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
@@ -64,17 +58,14 @@ COMMON_CONFIG = {
 }
 
 
-# The cookie name and path don't really matter, just that it has to be coherent
-# between the callback & redirect handlers.
-COOKIE_NAME = b"oidc_session"
-COOKIE_PATH = "/_synapse/oidc"
-
-
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
     @staticmethod
     def parse_config(config):
         return
 
+    def __init__(self, config):
+        pass
+
     def get_remote_user_id(self, userinfo):
         return userinfo["sub"]
 
@@ -97,16 +88,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-def simple_async_mock(return_value=None, raises=None):
-    # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args, **kwargs):
-        if raises:
-            raise raises
-        return return_value
-
-    return Mock(side_effect=cb)
-
-
 async def get_json(url):
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
@@ -127,6 +108,9 @@ async def get_json(url):
 
 
 class OidcHandlerTestCase(HomeserverTestCase):
+    if not HAS_OIDC:
+        skip = "requires OIDC"
+
     def default_config(self):
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
@@ -155,6 +139,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
 
         self.handler = hs.get_oidc_handler()
+        self.provider = self.handler._providers["oidc"]
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
@@ -166,27 +151,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def metadata_edit(self, values):
-        return patch.dict(self.handler._provider_metadata, values)
+        return patch.dict(self.provider._provider_metadata, values)
 
     def assertRenderedError(self, error, error_description=None):
+        self.render_error.assert_called_once()
         args = self.render_error.call_args[0]
         self.assertEqual(args[1], error)
         if error_description is not None:
             self.assertEqual(args[2], error_description)
         # Reset the render_error mock
         self.render_error.reset_mock()
+        return args
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
-        self.assertEqual(self.handler._callback_url, CALLBACK_URL)
-        self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
-        self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+        self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+        self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+        self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {"discover": True}})
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
-        metadata = self.get_success(self.handler.load_metadata())
+        metadata = self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
 
         self.assertEqual(metadata.issuer, ISSUER)
@@ -198,47 +185,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # subsequent calls should be cached
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
-        jwks = self.get_success(self.handler.load_jwks())
+        jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
         self.assertEqual(jwks, {"keys": []})
 
         # subsequent calls should be cached…
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks())
+        self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_not_called()
 
         # …unless forced
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks(force=True))
+        self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
 
         # Throw if the JWKS uri is missing
         with self.metadata_edit({"jwks_uri": None}):
-            self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+            self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
         # Return empty key set if JWKS are not used
-        self.handler._scopes = []  # not asking the openid scope
+        self.provider._scopes = []  # not asking the openid scope
         self.http_client.get_json.reset_mock()
-        jwks = self.get_success(self.handler.load_jwks(force=True))
+        jwks = self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
-        h = self.handler
+        h = self.provider
 
         # Default test config does not throw
         h._validate_metadata()
@@ -317,13 +304,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
-            self.handler._validate_metadata()
+            self.provider._validate_metadata()
 
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["addCookie"])
         url = self.get_success(
-            self.handler.handle_redirect_request(req, b"http://client/redirect")
+            self.provider.handle_redirect_request(req, b"http://client/redirect")
         )
         url = urlparse(url)
         auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -347,14 +334,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # For some reason, call.args does not work with python3.5
         args = calls[0][0]
         kwargs = calls[0][1]
-        self.assertEqual(args[0], COOKIE_NAME)
-        self.assertEqual(kwargs["path"], COOKIE_PATH)
+
+        # The cookie name and path don't really matter, just that it has to be coherent
+        # between the callback & redirect handlers.
+        self.assertEqual(args[0], b"oidc_session")
+        self.assertEqual(kwargs["path"], "/_synapse/client/oidc")
         cookie = args[1]
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._get_value_from_macaroon(macaroon, "state")
-        nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
-        redirect = self.handler._get_value_from_macaroon(
+        state = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "state"
+        )
+        nonce = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "nonce"
+        )
+        redirect = self.handler._token_generator._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
 
@@ -384,31 +378,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
          - when the userinfo fetching fails
          - when the code exchange fails
         """
+
+        # ensure that we are correctly testing the fallback when "get_extra_attributes"
+        # is not implemented.
+        mapping_provider = self.provider._user_mapping_provider
+        with self.assertRaises(AttributeError):
+            _ = mapping_provider.get_extra_attributes
+
         token = {
             "type": "bearer",
             "id_token": "id_token",
             "access_token": "access_token",
         }
+        username = "bar"
         userinfo = {
             "sub": "foo",
-            "preferred_username": "bar",
+            "username": username,
         }
-        user_id = "@foo:domain.org"
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        expected_user_id = "@%s:%s" % (username, self.hs.hostname)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         code = "code"
         state = "state"
@@ -416,74 +408,61 @@ class OidcHandlerTestCase(HomeserverTestCase):
         client_redirect_url = "http://client/redirect"
         user_agent = "Browser"
         ip_address = "10.0.0.1"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+        request = _build_callback_request(
+            code, state, session, user_agent=user_agent, ip_address=ip_address
         )
 
-        request.args = {}
-        request.args[b"code"] = [code.encode("utf-8")]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = ip_address
-        request.get_user_agent.return_value = user_agent
-
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, None, new_user=True
         )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
-        )
-        self.handler._fetch_userinfo.assert_not_called()
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+        self.provider._fetch_userinfo.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
-        self.handler._map_userinfo_to_user = simple_async_mock(
-            raises=MappingException()
-        )
-        self.get_success(self.handler.handle_oidc_callback(request))
-        self.assertRenderedError("mapping_error")
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+        with patch.object(
+            self.provider,
+            "_remote_id_from_userinfo",
+            new=Mock(side_effect=MappingException()),
+        ):
+            self.get_success(self.handler.handle_oidc_callback(request))
+            self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.handler._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
-        self.handler._auth_handler.complete_sso_login.reset_mock()
-        self.handler._exchange_code.reset_mock()
-        self.handler._parse_id_token.reset_mock()
-        self.handler._map_userinfo_to_user.reset_mock()
-        self.handler._fetch_userinfo.reset_mock()
+        auth_handler.complete_sso_login.reset_mock()
+        self.provider._exchange_code.reset_mock()
+        self.provider._parse_id_token.reset_mock()
+        self.provider._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
-        self.handler._scopes = []  # do not ask the "openid" scope
+        self.provider._scopes = []  # do not ask the "openid" scope
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
-        )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_not_called()
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, None, new_user=False
         )
-        self.handler._fetch_userinfo.assert_called_once_with(token)
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_not_called()
+        self.provider._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
-        self.handler._exchange_code = simple_async_mock(
+        from synapse.handlers.oidc_handler import OidcError
+
+        self.provider._exchange_code = simple_async_mock(
             raises=OidcError("invalid_request")
         )
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -513,11 +492,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_session")
 
         # Mismatching session
-        session = self.handler._generate_oidc_session_token(
-            state="state",
-            nonce="nonce",
-            client_redirect_url="http://client/redirect",
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
         )
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
@@ -541,7 +517,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
         )
         code = "code"
-        ret = self.get_success(self.handler._exchange_code(code))
+        ret = self.get_success(self.provider._exchange_code(code))
         kwargs = self.http_client.request.call_args[1]
 
         self.assertEqual(ret, token)
@@ -563,7 +539,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 body=b'{"error": "foo", "error_description": "bar"}',
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        from synapse.handlers.oidc_handler import OidcError
+
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "foo")
         self.assertEqual(exc.value.error_description, "bar")
 
@@ -573,7 +551,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=500, phrase=b"Internal Server Error", body=b"Not JSON",
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
@@ -585,14 +563,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             )
         )
 
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
@@ -601,7 +579,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=200, phrase=b"OK", body=b'{"error": "some_error"}',
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
 
     @override_config(
@@ -624,72 +602,56 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         userinfo = {
             "sub": "foo",
+            "username": "foo",
             "phone": "1234567",
         }
-        user_id = "@foo:domain.org"
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         state = "state"
         client_redirect_url = "http://client/redirect"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state=state, nonce="nonce", client_redirect_url=client_redirect_url,
         )
-
-        request.args = {}
-        request.args[b"code"] = [b"code"]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = "10.0.0.1"
-        request.get_user_agent.return_value = "Browser"
+        request = _build_callback_request("code", state, session)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {"phone": "1234567"},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@foo:test",
+            request,
+            client_redirect_url,
+            {"phone": "1234567"},
+            new_user=True,
         )
 
     def test_map_userinfo_to_user(self):
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         userinfo = {
             "sub": "test_user",
             "username": "test_user",
         }
-        # The token doesn't matter with the default user mapping provider.
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", ANY, ANY, None, new_user=True
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Some providers return an integer ID.
         userinfo = {
             "sub": 1234,
             "username": "test_user_2",
         }
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user_2:test", ANY, ANY, None, new_user=True
         )
-        self.assertEqual(mxid, "@test_user_2:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Test if the mxid is already taken
         store = self.hs.get_datastore()
@@ -698,14 +660,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
-        self.assertEqual(
-            str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error",
+            "Mapping provider does not support de-duplicating Matrix IDs",
         )
 
     @override_config({"oidc_config": {"allow_existing_users": True}})
@@ -717,26 +676,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
 
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
@@ -747,13 +706,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
@@ -770,14 +727,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        args = self.assertRenderedError("mapping_error")
         self.assertTrue(
-            str(e.value).startswith(
+            args[2].startswith(
                 "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
             )
         )
@@ -788,28 +742,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@TEST_USER_2:test")
 
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        userinfo = {
-            "sub": "test2",
-            "username": "föö",
-        }
-        token = {}
-
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(
+            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
         )
-        self.assertEqual(str(e.value), "localpart is invalid: föö")
+        self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
         {
@@ -822,6 +765,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_map_userinfo_to_user_retries(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         store = self.hs.get_datastore()
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
@@ -830,14 +776,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
-        )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
         # test_user is already taken, so test_user1 gets registered instead.
-        self.assertEqual(mxid, "@test_user1:test")
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user1:test", ANY, ANY, None, new_user=True
+        )
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register all of the potential mxids for a particular OIDC username.
         self.get_success(
@@ -853,12 +798,128 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
-        self.assertEqual(
-            str(e.value), "Unable to generate a Matrix ID from the SSO response"
+
+    def test_empty_localpart(self):
+        """Attempts to map onto an empty localpart should be rejected."""
+        userinfo = {
+            "sub": "tester",
+            "username": "",
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+    @override_config(
+        {
+            "oidc_config": {
+                "user_mapping_provider": {
+                    "config": {"localpart_template": "{{ user.username }}"}
+                }
+            }
+        }
+    )
+    def test_null_localpart(self):
+        """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
+        userinfo = {
+            "sub": "tester",
+            "username": None,
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+    def _generate_oidc_session_token(
+        self,
+        state: str,
+        nonce: str,
+        client_redirect_url: str,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        from synapse.handlers.oidc_handler import OidcSessionData
+
+        return self.handler._token_generator.generate_oidc_session_token(
+            state=state,
+            session_data=OidcSessionData(
+                idp_id="oidc",
+                nonce=nonce,
+                client_redirect_url=client_redirect_url,
+                ui_auth_session_id=ui_auth_session_id,
+            ),
         )
+
+
+async def _make_callback_with_userinfo(
+    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+    """Mock up an OIDC callback with the given userinfo dict
+
+    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
+
+    Args:
+        hs: the HomeServer impl to send the callback to.
+        userinfo: the OIDC userinfo dict
+        client_redirect_url: the URL to redirect to on success.
+    """
+    from synapse.handlers.oidc_handler import OidcSessionData
+
+    handler = hs.get_oidc_handler()
+    provider = handler._providers["oidc"]
+    provider._exchange_code = simple_async_mock(return_value={})
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+
+    state = "state"
+    session = handler._token_generator.generate_oidc_session_token(
+        state=state,
+        session_data=OidcSessionData(
+            idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
+        ),
+    )
+    request = _build_callback_request("code", state, session)
+
+    await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+    code: str,
+    state: str,
+    session: str,
+    user_agent: str = "Browser",
+    ip_address: str = "10.0.0.1",
+):
+    """Builds a fake SynapseRequest to mock the browser callback
+
+    Returns a Mock object which looks like the SynapseRequest we get from a browser
+    after SSO (before we return to the client)
+
+    Args:
+        code: the authorization code which would have been returned by the OIDC
+           provider
+        state: the "state" param which would have been passed around in the
+           query param. Should be the same as was embedded in the session in
+           _build_oidc_session.
+        session: the "session" which would have been passed around in the cookie.
+        user_agent: the user-agent to present
+        ip_address: the IP address to pretend the request came from
+    """
+    request = Mock(
+        spec=[
+            "args",
+            "getCookie",
+            "addCookie",
+            "requestHeaders",
+            "getClientIP",
+            "getHeader",
+        ]
+    )
+
+    request.getCookie.return_value = session
+    request.args = {}
+    request.args[b"code"] = [code.encode("utf-8")]
+    request.args[b"state"] = [state.encode("utf-8")]
+    request.getClientIP.return_value = ip_address
+    return request
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ceaf0902d2..f816594ee4 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -432,6 +432,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
+            **providers_config(CustomAuthProvider),
+            "password_config": {"enabled": False, "localdb_enabled": False},
+        }
+    )
+    def test_custom_auth_password_disabled_localdb_enabled(self):
+        """Check the localdb_enabled == enabled == False
+
+        Regression test for https://github.com/matrix-org/synapse/issues/8914: check
+        that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
+        cause an exception.
+        """
+        self.register_user("localuser", "localpass")
+
+        flows = self._get_login_flows()
+        self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+        # login shouldn't work and should be rejected with a 400 ("unknown login type")
+        channel = self._send_password_login("localuser", "localpass")
+        self.assertEqual(channel.code, 400, channel.result)
+        mock_password_provider.check_auth.assert_not_called()
+
+    @override_config(
+        {
             **providers_config(PasswordCustomAuthProvider),
             "password_config": {"enabled": False},
         }
@@ -528,7 +551,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
 
     def _get_login_flows(self) -> JsonDict:
-        _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)
         return channel.json_body["flows"]
 
@@ -537,7 +560,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     def _send_login(self, type, user, **params) -> FakeChannel:
         params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
-        _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+        channel = self.make_request("POST", "/_matrix/client/r0/login", params)
         return channel
 
     def _start_delete_device_session(self, access_token, device_id) -> str:
@@ -574,7 +597,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
-        _, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", "devices/" + device, body, access_token=access_token
         )
         return channel
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 8ed67640f8..0794b32c9c 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -463,7 +463,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            "server", http_client=None, federation_sender=Mock()
+            "server", federation_http_client=None, federation_sender=Mock()
         )
         return hs
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..022943a10a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
 
         hs = yield setup_test_homeserver(
             self.addCleanup,
-            http_client=None,
-            resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
@@ -107,6 +105,21 @@ class ProfileTestCase(unittest.TestCase):
             "Frank",
         )
 
+        # Set displayname to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), ""
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     def test_set_my_name_if_disabled(self):
         self.hs.config.enable_set_displayname = False
@@ -225,6 +238,21 @@ class ProfileTestCase(unittest.TestCase):
             "http://my.server/me.png",
         )
 
+        # Set avatar to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank, synapse.types.create_requester(self.frank), "",
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
+        )
+
     @defer.inlineCallbacks
     def test_set_my_avatar_if_disabled(self):
         self.hs.config.enable_set_avatar_url = False
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 45dc17aba5..a8d6c0f617 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,13 +12,35 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
+from typing import Optional
+
+from mock import Mock
+
 import attr
 
 from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
 
+from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
+# Check if we have the dependencies to run the tests.
+try:
+    import saml2.config
+    from saml2.sigver import SigverError
+
+    has_saml2 = True
+
+    # pysaml2 can be installed and imported, but might not be able to find xmlsec1.
+    config = saml2.config.SPConfig()
+    try:
+        config.load({"metadata": {}})
+        has_xmlsec1 = True
+    except SigverError:
+        has_xmlsec1 = False
+except ImportError:
+    has_saml2 = False
+    has_xmlsec1 = False
+
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
 
@@ -26,6 +48,8 @@ BASE_URL = "https://synapse/"
 @attr.s
 class FakeAuthnResponse:
     ava = attr.ib(type=dict)
+    assertions = attr.ib(type=list, factory=list)
+    in_response_to = attr.ib(type=Optional[str], default=None)
 
 
 class TestMappingProvider:
@@ -86,17 +110,29 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         return hs
 
+    if not has_saml2:
+        skip = "Requires pysaml2"
+    elif not has_xmlsec1:
+        skip = "Requires xmlsec1"
+
     def test_map_saml_response_to_user(self):
         """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # send a mocked-up SAML response to the callback
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
-        # The redirect_url doesn't matter with the default user mapping provider.
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
-        self.assertEqual(mxid, "@test_user:test")
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
     def test_map_saml_response_to_existing_user(self):
@@ -106,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
 
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         # Map a user via SSO.
         saml_response = FakeAuthnResponse(
             {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
         )
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "", None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
 
         # Subsequent calls should map to the same mxid.
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        auth_handler.complete_sso_login.reset_mock()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "")
+        )
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "", None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
 
     def test_map_saml_response_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # mock out the error renderer too
+        sso_handler = self.hs.get_sso_handler()
+        sso_handler.render_error = Mock(return_value=None)
+
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
-        redirect_url = ""
-        e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
+        )
+        sso_handler.render_error.assert_called_once_with(
+            request, "mapping_error", "localpart is invalid: föö"
         )
-        self.assertEqual(str(e.value), "localpart is invalid: föö")
+        auth_handler.complete_sso_login.assert_not_called()
 
     def test_map_saml_response_to_user_retries(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+        # stub out the auth handler and error renderer
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+        sso_handler = self.hs.get_sso_handler()
+        sso_handler.render_error = Mock(return_value=None)
+
+        # register a user to occupy the first-choice MXID
         store = self.hs.get_datastore()
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
+
+        # send the fake SAML response
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
         )
+
         # test_user is already taken, so test_user1 gets registered instead.
-        self.assertEqual(mxid, "@test_user1:test")
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user1:test", request, "", None, new_user=True
+        )
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register all of the potential mxids for a particular SAML username.
         self.get_success(
@@ -165,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # Now attempt to map to a username, this will fail since all potential usernames are taken.
         saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
-        e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
         )
-        self.assertEqual(
-            str(e.value), "Unable to generate a Matrix ID from the SSO response"
+        sso_handler.render_error.assert_called_once_with(
+            request,
+            "mapping_error",
+            "Unable to generate a Matrix ID from the SSO response",
         )
+        auth_handler.complete_sso_login.assert_not_called()
 
     @override_config(
         {
@@ -185,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
         }
     )
     def test_map_saml_response_redirect(self):
+        """Test a mapping provider that raises a RedirectException"""
+
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
-        redirect_url = ""
+        request = _mock_request()
         e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
+            self.handler._handle_authn_response(request, saml_response, ""),
             RedirectException,
         )
         self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+
+def _mock_request():
+    """Returns a mock which will stand in as a SynapseRequest"""
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index abbdf2d524..96e5bdac4a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
 
 
 import json
+from typing import Dict
 
 from mock import ANY, Mock, call
 
 from twisted.internet import defer
+from twisted.web.resource import Resource
 
 from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
 from synapse.types import UserID, create_requester
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
-from tests.utils import register_federation_servlets
 
 # Some local users to test with
 U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
 
 
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-    servlets = [register_federation_servlets]
-
     def make_homeserver(self, reactor, clock):
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
@@ -70,13 +70,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(
             notifier=Mock(),
-            http_client=mock_federation_client,
+            federation_http_client=mock_federation_client,
             keyring=mock_keyring,
             replication_streams={},
         )
 
         return hs
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TransportLayerServer(self.hs)
+        return d
+
     def prepare(self, reactor, clock, hs):
         mock_notifier = hs.get_notifier()
         self.on_new_event = mock_notifier.on_new_event
@@ -192,7 +197,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        put_json = self.hs.get_http_client().put_json
+        put_json = self.hs.get_federation_http_client().put_json
         put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
@@ -215,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        (request, channel) = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/federation/v1/send/1000000",
             _make_edu_transaction_json(
@@ -270,7 +275,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
-        put_json = self.hs.get_http_client().put_json
+        put_json = self.hs.get_federation_http_client().put_json
         put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 98e5af2072..9c886d671a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -54,6 +54,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
                 user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
             )
         )
+        regular_user_id = "@regular:test"
+        self.get_success(
+            self.store.register_user(user_id=regular_user_id, password_hash=None)
+        )
 
         self.get_success(
             self.handler.handle_local_profile_change(support_user_id, None)
@@ -63,13 +67,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         display_name = "display_name"
 
         profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
-        regular_user_id = "@regular:test"
         self.get_success(
             self.handler.handle_local_profile_change(regular_user_id, profile_info)
         )
         profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
         self.assertTrue(profile["display_name"] == display_name)
 
+    def test_handle_local_profile_change_with_deactivated_user(self):
+        # create user
+        r_user_id = "@regular:test"
+        self.get_success(
+            self.store.register_user(user_id=r_user_id, password_hash=None)
+        )
+
+        # update profile
+        display_name = "Regular User"
+        profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
+        self.get_success(
+            self.handler.handle_local_profile_change(r_user_id, profile_info)
+        )
+
+        # profile is in directory
+        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        self.assertTrue(profile["display_name"] == display_name)
+
+        # deactivate user
+        self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
+        self.get_success(self.handler.handle_user_deactivated(r_user_id))
+
+        # profile is not in directory
+        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        self.assertTrue(profile is None)
+
+        # update profile after deactivation
+        self.get_success(
+            self.handler.handle_local_profile_change(r_user_id, profile_info)
+        )
+
+        # profile is furthermore not in directory
+        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        self.assertTrue(profile is None)
+
     def test_handle_user_deactivated_support_user(self):
         s_user_id = "@support:test"
         self.get_success(
@@ -270,7 +308,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         spam_checker = self.hs.get_spam_checker()
 
         class AllowAll:
-            def check_username_for_spam(self, user_profile):
+            async def check_username_for_spam(self, user_profile):
                 # Allow all users.
                 return False
 
@@ -283,7 +321,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Configure a spam checker that filters all users.
         class BlockAll:
-            def check_username_for_spam(self, user_profile):
+            async def check_username_for_spam(self, user_profile):
                 # All users are spammy.
                 return True
 
@@ -534,7 +572,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2)
 
         # Assert user directory is not empty
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", b"user_directory/search", b'{"search_term":"user2"}'
         )
         self.assertEquals(200, channel.code, channel.result)
@@ -542,7 +580,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
 
         # Disable user directory and check search returns nothing
         self.config.user_directory_search_enabled = False
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", b"user_directory/search", b'{"search_term":"user2"}'
         )
         self.assertEquals(200, channel.code, channel.result)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 8b5ad4574f..686012dd25 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -17,6 +17,7 @@ import logging
 from mock import Mock
 
 import treq
+from netaddr import IPSet
 from service_identity import VerificationError
 from zope.interface import implementer
 
@@ -35,6 +36,7 @@ from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
 from synapse.http.federation.srv_resolver import Server
 from synapse.http.federation.well_known_resolver import (
+    WELL_KNOWN_MAX_SIZE,
     WellKnownResolver,
     _cache_period_from_headers,
 )
@@ -103,6 +105,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             reactor=self.reactor,
             tls_client_options_factory=self.tls_factory,
             user_agent="test-agent",  # Note that this is unused since _well_known_resolver is provided.
+            ip_blacklist=IPSet(),
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=self.well_known_resolver,
         )
@@ -736,6 +739,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             reactor=self.reactor,
             tls_client_options_factory=tls_factory,
             user_agent=b"test-agent",  # This is unused since _well_known_resolver is passed below.
+            ip_blacklist=IPSet(),
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=WellKnownResolver(
                 self.reactor,
@@ -1091,7 +1095,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # Expire both caches and repeat the request
         self.reactor.pump((10000.0,))
 
-        # Repated the request, this time it should fail if the lookup fails.
+        # Repeat the request, this time it should fail if the lookup fails.
         fetch_d = defer.ensureDeferred(
             self.well_known_resolver.get_well_known(b"testserv")
         )
@@ -1104,6 +1108,32 @@ class MatrixFederationAgentTests(unittest.TestCase):
         r = self.successResultOf(fetch_d)
         self.assertEqual(r.delegated_server, None)
 
+    def test_well_known_too_large(self):
+        """A well-known query that returns a result which is too large should be rejected."""
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        fetch_d = defer.ensureDeferred(
+            self.well_known_resolver.get_well_known(b"testserv")
+        )
+
+        # there should be an attempt to connect on port 443 for the .well-known
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 443)
+
+        self._handle_well_known_connection(
+            client_factory,
+            expected_sni=b"testserv",
+            response_headers={b"Cache-Control": b"max-age=1000"},
+            content=b'{ "m.server": "' + (b"a" * WELL_KNOWN_MAX_SIZE) + b'" }',
+        )
+
+        # The result is successful, but disabled delegation.
+        r = self.successResultOf(fetch_d)
+        self.assertIsNone(r.delegated_server)
+
     def test_srv_fallbacks(self):
         """Test that other SRV results are tried if the first one fails.
         """
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 05e9c449be..453391a5a5 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -46,16 +46,16 @@ class AdditionalResourceTests(HomeserverTestCase):
         handler = _AsyncTestCustomEndpoint({}, None).handle_request
         resource = AdditionalResource(self.hs, handler)
 
-        request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+        channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
 
     def test_sync(self):
         handler = _SyncTestCustomEndpoint({}, None).handle_request
         resource = AdditionalResource(self.hs, handler)
 
-        request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+        channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
new file mode 100644
index 0000000000..f17c122e93
--- /dev/null
+++ b/tests/http/test_client.py
@@ -0,0 +1,101 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+from io import BytesIO
+
+from mock import Mock
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
+from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
+
+from tests.unittest import TestCase
+
+
+class ReadBodyWithMaxSizeTests(TestCase):
+    def setUp(self):
+        """Start reading the body, returns the response, result and proto"""
+        self.response = Mock()
+        self.result = BytesIO()
+        self.deferred = read_body_with_max_size(self.response, self.result, 6)
+
+        # Fish the protocol out of the response.
+        self.protocol = self.response.deliverBody.call_args[0][0]
+        self.protocol.transport = Mock()
+
+    def _cleanup_error(self):
+        """Ensure that the error in the Deferred is handled gracefully."""
+        called = [False]
+
+        def errback(f):
+            called[0] = True
+
+        self.deferred.addErrback(errback)
+        self.assertTrue(called[0])
+
+    def test_no_error(self):
+        """A response that is NOT too large."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"12345")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"12345")
+        self.assertEqual(self.deferred.result, 5)
+
+    def test_too_large(self):
+        """A response which is too large raises an exception."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"1234567890")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"1234567890")
+        self.assertIsInstance(self.deferred.result, Failure)
+        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+        self._cleanup_error()
+
+    def test_multiple_packets(self):
+        """Data should be accummulated through mutliple packets."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"12")
+        self.protocol.dataReceived(b"34")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"1234")
+        self.assertEqual(self.deferred.result, 4)
+
+    def test_additional_data(self):
+        """A connection can receive data after being closed."""
+
+        # Start sending data.
+        self.protocol.dataReceived(b"1234567890")
+        self.assertIsInstance(self.deferred.result, Failure)
+        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+        self.protocol.transport.loseConnection.assert_called_once()
+
+        # More data might have come in.
+        self.protocol.dataReceived(b"1234567890")
+        # Close the connection.
+        self.protocol.connectionLost(Failure(ResponseDone()))
+
+        self.assertEqual(self.result.getvalue(), b"1234567890")
+        self.assertIsInstance(self.deferred.result, Failure)
+        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
+        self._cleanup_error()
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index b2e9533b07..d06ea518ce 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name
+from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name
 
 from tests import unittest
 
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 212484a7fe..9c52c8fdca 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -560,4 +560,4 @@ class FederationClientTests(HomeserverTestCase):
         self.pump()
 
         f = self.failureResultOf(test_d)
-        self.assertIsInstance(f.value, ValueError)
+        self.assertIsInstance(f.value, RequestSendFailed)
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 22abf76515..9a56e1c14a 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -15,12 +15,14 @@
 import logging
 
 import treq
+from netaddr import IPSet
 
 from twisted.internet import interfaces  # noqa: F401
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.web.http import HTTPChannel
 
+from synapse.http.client import BlacklistingReactorWrapper
 from synapse.http.proxyagent import ProxyAgent
 
 from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
@@ -292,6 +294,134 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")
 
+    def test_http_request_via_proxy_with_blacklist(self):
+        # The blacklist includes the configured proxy IP.
+        agent = ProxyAgent(
+            BlacklistingReactorWrapper(
+                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+            ),
+            self.reactor,
+            http_proxy=b"proxy.com:8888",
+        )
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"http://test.com")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 8888)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"http://test.com")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+    def test_https_request_via_proxy_with_blacklist(self):
+        # The blacklist includes the configured proxy IP.
+        agent = ProxyAgent(
+            BlacklistingReactorWrapper(
+                self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
+            ),
+            self.reactor,
+            contextFactory=get_test_https_policy(),
+            https_proxy=b"proxy.com",
+        )
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"https://test.com/abc")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 1080)
+
+        # make a test HTTP server, and wire up the client
+        proxy_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # fish the transports back out so that we can do the old switcheroo
+        s2c_transport = proxy_server.transport
+        client_protocol = s2c_transport.other
+        c2s_transport = client_protocol.transport
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending CONNECT request
+        self.assertEqual(len(proxy_server.requests), 1)
+
+        request = proxy_server.requests[0]
+        self.assertEqual(request.method, b"CONNECT")
+        self.assertEqual(request.path, b"test.com:443")
+
+        # tell the proxy server not to close the connection
+        proxy_server.persistent = True
+
+        # this just stops the http Request trying to do a chunked response
+        # request.setHeader(b"Content-Length", b"0")
+        request.finish()
+
+        # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+        ssl_protocol = ssl_factory.buildProtocol(None)
+        http_server = ssl_protocol.wrappedProtocol
+
+        ssl_protocol.makeConnection(
+            FakeTransport(client_protocol, self.reactor, ssl_protocol)
+        )
+        c2s_transport.other = ssl_protocol
+
+        self.reactor.advance(0)
+
+        server_name = ssl_protocol._tlsConnection.get_servername()
+        expected_sni = b"test.com"
+        self.assertEqual(
+            server_name,
+            expected_sni,
+            "Expected SNI %s but got %s" % (expected_sni, server_name),
+        )
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/abc")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
 
 def _wrap_server_factory_for_tls(factory, sanlist=None):
     """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 73f469b802..48a74e2eee 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -18,30 +18,35 @@ import logging
 from io import StringIO
 
 from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
+from synapse.logging.context import LoggingContext, LoggingContextFilter
 
 from tests.logging import LoggerCleanupMixin
 from tests.unittest import TestCase
 
 
 class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
+    def setUp(self):
+        self.output = StringIO()
+
+    def get_log_line(self):
+        # One log message, with a single trailing newline.
+        data = self.output.getvalue()
+        logs = data.splitlines()
+        self.assertEqual(len(logs), 1)
+        self.assertEqual(data.count("\n"), 1)
+        return json.loads(logs[0])
+
     def test_terse_json_output(self):
         """
         The Terse JSON formatter converts log messages to JSON.
         """
-        output = StringIO()
-
-        handler = logging.StreamHandler(output)
+        handler = logging.StreamHandler(self.output)
         handler.setFormatter(TerseJsonFormatter())
         logger = self.get_logger(handler)
 
         logger.info("Hello there, %s!", "wally")
 
-        # One log message, with a single trailing newline.
-        data = output.getvalue()
-        logs = data.splitlines()
-        self.assertEqual(len(logs), 1)
-        self.assertEqual(data.count("\n"), 1)
-        log = json.loads(logs[0])
+        log = self.get_log_line()
 
         # The terse logger should give us these keys.
         expected_log_keys = [
@@ -57,9 +62,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         """
         Additional information can be included in the structured logging.
         """
-        output = StringIO()
-
-        handler = logging.StreamHandler(output)
+        handler = logging.StreamHandler(self.output)
         handler.setFormatter(TerseJsonFormatter())
         logger = self.get_logger(handler)
 
@@ -67,12 +70,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
             "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
         )
 
-        # One log message, with a single trailing newline.
-        data = output.getvalue()
-        logs = data.splitlines()
-        self.assertEqual(len(logs), 1)
-        self.assertEqual(data.count("\n"), 1)
-        log = json.loads(logs[0])
+        log = self.get_log_line()
 
         # The terse logger should give us these keys.
         expected_log_keys = [
@@ -96,26 +94,44 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         """
         The Terse JSON formatter converts log messages to JSON.
         """
-        output = StringIO()
-
-        handler = logging.StreamHandler(output)
+        handler = logging.StreamHandler(self.output)
         handler.setFormatter(JsonFormatter())
         logger = self.get_logger(handler)
 
         logger.info("Hello there, %s!", "wally")
 
-        # One log message, with a single trailing newline.
-        data = output.getvalue()
-        logs = data.splitlines()
-        self.assertEqual(len(logs), 1)
-        self.assertEqual(data.count("\n"), 1)
-        log = json.loads(logs[0])
+        log = self.get_log_line()
+
+        # The terse logger should give us these keys.
+        expected_log_keys = [
+            "log",
+            "level",
+            "namespace",
+        ]
+        self.assertCountEqual(log.keys(), expected_log_keys)
+        self.assertEqual(log["log"], "Hello there, wally!")
+
+    def test_with_context(self):
+        """
+        The logging context should be added to the JSON response.
+        """
+        handler = logging.StreamHandler(self.output)
+        handler.setFormatter(JsonFormatter())
+        handler.addFilter(LoggingContextFilter())
+        logger = self.get_logger(handler)
+
+        with LoggingContext(request="test"):
+            logger.info("Hello there, %s!", "wally")
+
+        log = self.get_log_line()
 
         # The terse logger should give us these keys.
         expected_log_keys = [
             "log",
             "level",
             "namespace",
+            "request",
         ]
         self.assertCountEqual(log.keys(), expected_log_keys)
         self.assertEqual(log["log"], "Hello there, wally!")
+        self.assertEqual(log["request"], "test")
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index bcdcafa5a9..c4e1e7ed85 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -187,6 +187,36 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about those messages
         self._check_for_mail()
 
+    def test_multiple_rooms(self):
+        # We want to test multiple notifications from multiple rooms, so we pause
+        # processing of push while we send messages.
+        self.pusher._pause_processing()
+
+        # Create a simple room with multiple other users
+        rooms = [
+            self.helper.create_room_as(self.user_id, tok=self.access_token),
+            self.helper.create_room_as(self.user_id, tok=self.access_token),
+        ]
+
+        for r, other in zip(rooms, self.others):
+            self.helper.invite(
+                room=r, src=self.user_id, tok=self.access_token, targ=other.id
+            )
+            self.helper.join(room=r, user=other.id, tok=other.token)
+
+        # The other users send some messages
+        self.helper.send(rooms[0], body="Hi!", tok=self.others[0].token)
+        self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+        self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+
+        # Nothing should have happened yet, as we're paused.
+        assert not self.email_attempts
+
+        self.pusher._resume_processing()
+
+        # We should get emailed about those messages
+        self._check_for_mail()
+
     def test_encrypted_message(self):
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
         self.helper.invite(
@@ -209,7 +239,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        last_stream_ordering = pushers[0]["last_stream_ordering"]
+        last_stream_ordering = pushers[0].last_stream_ordering
 
         # Advance time a bit, so the pusher will register something has happened
         self.pump(10)
@@ -220,7 +250,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+        self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
         # One email was attempted to be sent
         self.assertEqual(len(self.email_attempts), 1)
@@ -238,4 +268,4 @@ class EmailPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+        self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index f118430309..60f0820cff 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -18,6 +18,7 @@ from twisted.internet.defer import Deferred
 
 import synapse.rest.admin
 from synapse.logging.context import make_deferred_yieldable
+from synapse.push import PusherConfigException
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import receipts
 
@@ -34,6 +35,11 @@ class HTTPPusherTests(HomeserverTestCase):
     user_id = True
     hijack_auth = False
 
+    def default_config(self):
+        config = super().default_config()
+        config["start_pushers"] = True
+        return config
+
     def make_homeserver(self, reactor, clock):
         self.push_attempts = []
 
@@ -46,13 +52,49 @@ class HTTPPusherTests(HomeserverTestCase):
 
         m.post_json_get_json = post_json_get_json
 
-        config = self.default_config()
-        config["start_pushers"] = True
-
-        hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
+        hs = self.setup_test_homeserver(proxied_blacklisted_http_client=m)
 
         return hs
 
+    def test_invalid_configuration(self):
+        """Invalid push configurations should be rejected."""
+        # Register the user who gets notified
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Register the pusher
+        user_tuple = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple.token_id
+
+        def test_data(data):
+            self.get_failure(
+                self.hs.get_pusherpool().add_pusher(
+                    user_id=user_id,
+                    access_token=token_id,
+                    kind="http",
+                    app_id="m.http",
+                    app_display_name="HTTP Push Notifications",
+                    device_display_name="pushy push",
+                    pushkey="a@example.com",
+                    lang=None,
+                    data=data,
+                ),
+                PusherConfigException,
+            )
+
+        # Data must be provided with a URL.
+        test_data(None)
+        test_data({})
+        test_data({"url": 1})
+        # A bare domain name isn't accepted.
+        test_data({"url": "example.com"})
+        # A URL without a path isn't accepted.
+        test_data({"url": "http://example.com"})
+        # A url with an incorrect path isn't accepted.
+        test_data({"url": "http://example.com/foo"})
+
     def test_sends_http(self):
         """
         The HTTP pusher will send pushes for each message to a HTTP endpoint
@@ -82,7 +124,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -102,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        last_stream_ordering = pushers[0]["last_stream_ordering"]
+        last_stream_ordering = pushers[0].last_stream_ordering
 
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
@@ -113,11 +155,13 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+        self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
         # One push was attempted to be sent -- it'll be the first message
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(
             self.push_attempts[0][2]["notification"]["content"]["body"], "Hi!"
         )
@@ -132,12 +176,14 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
-        last_stream_ordering = pushers[0]["last_stream_ordering"]
+        self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+        last_stream_ordering = pushers[0].last_stream_ordering
 
         # Now it'll try and send the second push message, which will be the second one
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(
             self.push_attempts[1][2]["notification"]["content"]["body"], "There!"
         )
@@ -152,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+        self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
 
     def test_sends_high_priority_for_encrypted(self):
         """
@@ -194,7 +240,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -230,7 +276,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Add yet another person — we want to make this room not a 1:1
@@ -268,7 +316,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
 
     def test_sends_high_priority_for_one_to_one_only(self):
@@ -310,7 +360,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -326,7 +376,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority — this is a one-to-one room
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Yet another user joins
@@ -345,7 +397,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -392,7 +446,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -408,7 +462,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Send another event, this time with no mention
@@ -417,7 +473,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -465,7 +523,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -485,7 +543,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it with high priority
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
         self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
 
         # Send another event, this time as someone without the power of @room
@@ -496,7 +556,9 @@ class HTTPPusherTests(HomeserverTestCase):
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
         self.assertEqual(len(self.push_attempts), 2)
-        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[1][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
@@ -570,7 +632,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -589,7 +651,9 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Check our push made it
         self.assertEqual(len(self.push_attempts), 1)
-        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(
+            self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify"
+        )
 
         # Check that the unread count for the room is 0
         #
@@ -603,7 +667,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # This will actually trigger a new notification to be sent out so that
         # even if the user does not receive another message, their unread
         # count goes down
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
             {},
diff --git a/tests/push/test_presentable_names.py b/tests/push/test_presentable_names.py
new file mode 100644
index 0000000000..aff563919d
--- /dev/null
+++ b/tests/push/test_presentable_names.py
@@ -0,0 +1,229 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+from typing import Iterable, Optional, Tuple
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.events import FrozenEvent
+from synapse.push.presentable_names import calculate_room_name
+from synapse.types import StateKey, StateMap
+
+from tests import unittest
+
+
+class MockDataStore:
+    """
+    A fake data store which stores a mapping of state key to event content.
+    (I.e. the state key is used as the event ID.)
+    """
+
+    def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
+        """
+        Args:
+            events: A state map to event contents.
+        """
+        self._events = {}
+
+        for i, (event_id, content) in enumerate(events):
+            self._events[event_id] = FrozenEvent(
+                {
+                    "event_id": "$event_id",
+                    "type": event_id[0],
+                    "sender": "@user:test",
+                    "state_key": event_id[1],
+                    "room_id": "#room:test",
+                    "content": content,
+                    "origin_server_ts": i,
+                },
+                RoomVersions.V1,
+            )
+
+    async def get_event(
+        self, event_id: StateKey, allow_none: bool = False
+    ) -> Optional[FrozenEvent]:
+        assert allow_none, "Mock not configured for allow_none = False"
+
+        return self._events.get(event_id)
+
+    async def get_events(self, event_ids: Iterable[StateKey]):
+        # This is cheating since it just returns all events.
+        return self._events
+
+
+class PresentableNamesTestCase(unittest.HomeserverTestCase):
+    USER_ID = "@test:test"
+    OTHER_USER_ID = "@user:test"
+
+    def _calculate_room_name(
+        self,
+        events: StateMap[dict],
+        user_id: str = "",
+        fallback_to_members: bool = True,
+        fallback_to_single_member: bool = True,
+    ):
+        # This isn't 100% accurate, but works with MockDataStore.
+        room_state_ids = {k[0]: k[0] for k in events}
+
+        return self.get_success(
+            calculate_room_name(
+                MockDataStore(events),
+                room_state_ids,
+                user_id or self.USER_ID,
+                fallback_to_members,
+                fallback_to_single_member,
+            )
+        )
+
+    def test_name(self):
+        """A room name event should be used."""
+        events = [
+            ((EventTypes.Name, ""), {"name": "test-name"}),
+        ]
+        self.assertEqual("test-name", self._calculate_room_name(events))
+
+        # Check if the event content has garbage.
+        events = [((EventTypes.Name, ""), {"foo": 1})]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+        events = [((EventTypes.Name, ""), {"name": 1})]
+        self.assertEqual(1, self._calculate_room_name(events))
+
+    def test_canonical_alias(self):
+        """An canonical alias should be used."""
+        events = [
+            ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
+        ]
+        self.assertEqual("#test-name:test", self._calculate_room_name(events))
+
+        # Check if the event content has garbage.
+        events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+        events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+    def test_invite(self):
+        """An invite has special behaviour."""
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+            ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
+        ]
+        self.assertEqual("Invite from Other User", self._calculate_room_name(events))
+        self.assertIsNone(
+            self._calculate_room_name(events, fallback_to_single_member=False)
+        )
+        # Ensure this logic is skipped if we don't fallback to members.
+        self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))
+
+        # Check if the event content has garbage.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+            ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+        ]
+        self.assertEqual("Invite from @user:test", self._calculate_room_name(events))
+
+        # No member event for sender.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
+        ]
+        self.assertEqual("Room Invite", self._calculate_room_name(events))
+
+    def test_no_members(self):
+        """Behaviour of an empty room."""
+        events = []
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+        # Note that events with invalid (or missing) membership are ignored.
+        events = [
+            ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
+            ((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
+        ]
+        self.assertEqual("Empty Room", self._calculate_room_name(events))
+
+    def test_no_other_members(self):
+        """Behaviour of a room with no other members in it."""
+        events = [
+            (
+                (EventTypes.Member, self.USER_ID),
+                {"membership": Membership.JOIN, "displayname": "Me"},
+            ),
+        ]
+        self.assertEqual("Me", self._calculate_room_name(events))
+
+        # Check if the event content has no displayname.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+        ]
+        self.assertEqual("@test:test", self._calculate_room_name(events))
+
+        # 3pid invite, use the other user (who is set as the sender).
+        events = [
+            ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+        ]
+        self.assertEqual(
+            "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
+        )
+
+        events = [
+            ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
+            ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
+        ]
+        self.assertEqual(
+            "Inviting email address",
+            self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
+        )
+
+    def test_one_other_member(self):
+        """Behaviour of a room with a single other member."""
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+            (
+                (EventTypes.Member, self.OTHER_USER_ID),
+                {"membership": Membership.JOIN, "displayname": "Other User"},
+            ),
+        ]
+        self.assertEqual("Other User", self._calculate_room_name(events))
+        self.assertIsNone(
+            self._calculate_room_name(events, fallback_to_single_member=False)
+        )
+
+        # Check if the event content has no displayname and is an invite.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+            (
+                (EventTypes.Member, self.OTHER_USER_ID),
+                {"membership": Membership.INVITE},
+            ),
+        ]
+        self.assertEqual("@user:test", self._calculate_room_name(events))
+
+    def test_other_members(self):
+        """Behaviour of a room with multiple other members."""
+        # Two other members.
+        events = [
+            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
+            (
+                (EventTypes.Member, self.OTHER_USER_ID),
+                {"membership": Membership.JOIN, "displayname": "Other User"},
+            ),
+            ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
+        ]
+        self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))
+
+        # Three or more other members.
+        events.append(
+            ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
+        )
+        self.assertEqual("Other User and 2 others", self._calculate_room_name(events))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 1f4b5ca2ac..4a841f5bb8 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
                 "type": "m.room.history_visibility",
                 "sender": "@user:test",
                 "state_key": "",
-                "room_id": "@room:test",
+                "room_id": "#room:test",
                 "content": content,
             },
             RoomVersions.V1,
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..d5dce1f83f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import attr
 
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
 )
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     if not hiredis:
         skip = "Requires hiredis"
 
-    servlets = [
-        streams.register_servlets,
-    ]
-
     def prepare(self, reactor, clock, hs):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
@@ -67,7 +64,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.worker_hs = self.setup_test_homeserver(
-            http_client=None,
+            federation_http_client=None,
             homeserver_to_use=GenericWorkerServer,
             config=self._get_worker_hs_config(),
             reactor=self.reactor,
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+        return d
+
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.generic_worker"
@@ -210,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Fake in memory Redis server that servers can connect to.
         self._redis_server = FakeRedisPubSubServer()
 
+        # We may have an attempt to connect to redis for the external cache already.
+        self.connect_any_redis_attempts()
+
         store = self.hs.get_datastore()
         self.database_pool = store.db_pool
 
@@ -264,7 +269,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
             extra_config: Any extra config to use for this instances.
             **kwargs: Options that get passed to `self.setup_test_homeserver`,
-                useful to e.g. pass some mocks for things like `http_client`
+                useful to e.g. pass some mocks for things like `federation_http_client`
 
         Returns:
             The new worker HomeServer instance.
@@ -399,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         fake one.
         """
         clients = self.reactor.tcpClients
-        self.assertEqual(len(clients), 1)
-        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
-        self.assertEqual(host, "localhost")
-        self.assertEqual(port, 6379)
+        while clients:
+            (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+            self.assertEqual(host, "localhost")
+            self.assertEqual(port, 6379)
 
-        client_protocol = client_factory.buildProtocol(None)
-        server_protocol = self._redis_server.buildProtocol(None)
+            client_protocol = client_factory.buildProtocol(None)
+            server_protocol = self._redis_server.buildProtocol(None)
 
-        client_to_server_transport = FakeTransport(
-            server_protocol, self.reactor, client_protocol
-        )
-        client_protocol.makeConnection(client_to_server_transport)
-
-        server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, server_protocol
-        )
-        server_protocol.makeConnection(server_to_client_transport)
+            client_to_server_transport = FakeTransport(
+                server_protocol, self.reactor, client_protocol
+            )
+            client_protocol.makeConnection(client_to_server_transport)
 
-        return client_to_server_transport, server_to_client_transport
+            server_to_client_transport = FakeTransport(
+                client_protocol, self.reactor, server_protocol
+            )
+            server_protocol.makeConnection(server_to_client_transport)
 
 
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -622,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
             (channel,) = args
             self._server.add_subscriber(self)
             self.send(["subscribe", channel, 1])
+
+        # Since we use SET/GET to cache things we can safely no-op them.
+        elif command == b"SET":
+            self.send("OK")
+        elif command == b"GET":
+            self.send(None)
         else:
             raise Exception("Unknown command")
 
@@ -643,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
             # We assume bytes are just unicode strings.
             obj = obj.decode("utf-8")
 
+        if obj is None:
+            return "$-1\r\n"
         if isinstance(obj, str):
             return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
         if isinstance(obj, int):
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..f35a5235e1
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+    """Test the authentication of HTTP calls between workers."""
+
+    servlets = [register.register_servlets]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        # This isn't a real configuration option but is used to provide the main
+        # homeserver and worker homeserver different options.
+        main_replication_secret = config.pop("main_replication_secret", None)
+        if main_replication_secret:
+            config["worker_replication_secret"] = main_replication_secret
+        return self.setup_test_homeserver(config=config)
+
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_app"] = "synapse.app.client_reader"
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+
+        return config
+
+    def _test_register(self) -> FakeChannel:
+        """Run the actual test:
+
+        1. Create a worker homeserver.
+        2. Start registration by providing a user/password.
+        3. Complete registration by providing dummy auth (this hits the main synapse).
+        4. Return the final request.
+
+        """
+        worker_hs = self.make_worker_hs("synapse.app.client_reader")
+        site = self._hs_to_site[worker_hs]
+
+        channel_1 = make_request(
+            self.reactor,
+            site,
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel_1.code, 401)
+
+        # Grab the session
+        session = channel_1.json_body["session"]
+
+        # also complete the dummy auth
+        return make_request(
+            self.reactor,
+            site,
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": "m.login.dummy"}},
+        )
+
+    def test_no_auth(self):
+        """With no authentication the request should finish.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+    @override_config({"main_replication_secret": "my-secret"})
+    def test_missing_auth(self):
+        """If the main process expects a secret that is not provided, an error results.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 500)
+
+    @override_config(
+        {
+            "main_replication_secret": "my-secret",
+            "worker_replication_secret": "wrong-secret",
+        }
+    )
+    def test_unauthorized(self):
+        """If the main process receives the wrong secret, an error results.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 500)
+
+    @override_config({"worker_replication_secret": "my-secret"})
+    def test_authorized(self):
+        """The request should finish when the worker provides the authentication header.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 96801db473..4608b65a0c 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,19 @@
 # limitations under the License.
 import logging
 
-from synapse.api.constants import LoginType
-from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha import register
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel, make_request
+from tests.server import make_request
 
 logger = logging.getLogger(__name__)
 
 
 class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
-    """Base class for tests of the replication streams"""
+    """Test using one or more client readers for registration."""
 
     servlets = [register.register_servlets]
 
-    def prepare(self, reactor, clock, hs):
-        self.recaptcha_checker = DummyRecaptchaChecker(hs)
-        auth_handler = hs.get_auth_handler()
-        auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.client_reader"
@@ -48,27 +40,27 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         worker_hs = self.make_worker_hs("synapse.app.client_reader")
         site = self._hs_to_site[worker_hs]
 
-        request_1, channel_1 = make_request(
+        channel_1 = make_request(
             self.reactor,
             site,
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_1.code, 401)
+        )
+        self.assertEqual(channel_1.code, 401)
 
         # Grab the session
         session = channel_1.json_body["session"]
 
         # also complete the dummy auth
-        request_2, channel_2 = make_request(
+        channel_2 = make_request(
             self.reactor,
             site,
             "POST",
             "register",
             {"auth": {"session": session, "type": "m.login.dummy"}},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_2.code, 200)
+        )
+        self.assertEqual(channel_2.code, 200)
 
         # We're given a registered user.
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
@@ -80,28 +72,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
 
         site_1 = self._hs_to_site[worker_hs_1]
-        request_1, channel_1 = make_request(
+        channel_1 = make_request(
             self.reactor,
             site_1,
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_1.code, 401)
+        )
+        self.assertEqual(channel_1.code, 401)
 
         # Grab the session
         session = channel_1.json_body["session"]
 
         # also complete the dummy auth
         site_2 = self._hs_to_site[worker_hs_2]
-        request_2, channel_2 = make_request(
+        channel_2 = make_request(
             self.reactor,
             site_2,
             "POST",
             "register",
             {"auth": {"session": session, "type": "m.login.dummy"}},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_2.code, 200)
+        )
+        self.assertEqual(channel_2.code, 200)
 
         # We're given a registered user.
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 779745ae9d..fffdb742c8 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {"send_federation": True},
-            http_client=mock_client,
+            federation_http_client=mock_client,
         )
 
         user = self.register_user("user", "pass")
@@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client1,
+            federation_http_client=mock_client1,
         )
 
         mock_client2 = Mock(spec=["put_json"])
@@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client2,
+            federation_http_client=mock_client2,
         )
 
         user = self.register_user("user2", "pass")
@@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client1,
+            federation_http_client=mock_client1,
         )
 
         mock_client2 = Mock(spec=["put_json"])
@@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client2,
+            federation_http_client=mock_client2,
         )
 
         user = self.register_user("user3", "pass")
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 48b574ccbe..d1feca961f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -48,7 +48,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.user_id = self.register_user("user", "pass")
         self.access_token = self.login("user", "pass")
 
-        self.reactor.lookups["example.com"] = "127.0.0.2"
+        self.reactor.lookups["example.com"] = "1.2.3.4"
 
     def default_config(self):
         conf = super().default_config()
@@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
             the media which the caller should respond to.
         """
         resource = hs.get_media_repository_resource().children[b"download"]
-        _, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "GET",
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 67c27a089f..800ad94a04 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "https://push.example.com/push"},
+                data={"url": "https://push.example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         self.make_worker_hs(
             "synapse.app.pusher",
             {"start_pushers": True},
-            proxied_http_client=http_client_mock,
+            proxied_blacklisted_http_client=http_client_mock,
         )
 
         event_id = self._create_pusher_and_send_msg("user")
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock.post_json_get_json.assert_called_once()
         self.assertEqual(
             http_client_mock.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
@@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "pusher1",
                 "pusher_instances": ["pusher1", "pusher2"],
             },
-            proxied_http_client=http_client_mock1,
+            proxied_blacklisted_http_client=http_client_mock1,
         )
 
         http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "pusher2",
                 "pusher_instances": ["pusher1", "pusher2"],
             },
-            proxied_http_client=http_client_mock2,
+            proxied_blacklisted_http_client=http_client_mock2,
         )
 
         # We choose a user name that we know should go to pusher1.
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock2.post_json_get_json.assert_not_called()
         self.assertEqual(
             http_client_mock1.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock2.post_json_get_json.assert_called_once()
         self.assertEqual(
             http_client_mock2.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 77fc3856d5..8d494ebc03 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -180,7 +180,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # Do an initial sync so that we're up to date.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
         )
         next_batch = channel.json_body["next_batch"]
@@ -206,7 +206,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
         # Check that syncing still gets the new event, despite the gap in the
         # stream IDs.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -236,7 +236,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
         first_event_in_room2 = response["event_id"]
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -261,7 +261,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
         self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -279,7 +279,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         # Paginating back in the first room should not produce any results, as
         # no events have happened in it. This tests that we are correctly
         # filtering results based on the vector clock portion.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -292,7 +292,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
         # Paginating back on the second room should produce the first event
         # again. This tests that pagination isn't completely broken.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -307,7 +307,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # Paginating forwards should give the same results
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -318,7 +318,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
         self.assertListEqual([], channel.json_body["chunk"])
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 4f76f8f768..9d22c04073 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -42,7 +42,7 @@ class VersionTestCase(unittest.HomeserverTestCase):
         return resource
 
     def test_version_string(self):
-        request, channel = self.make_request("GET", self.url, shorthand=False)
+        channel = self.make_request("GET", self.url, shorthand=False)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
@@ -58,8 +58,6 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -68,7 +66,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
 
     def test_delete_group(self):
         # Create a new group
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/create_group".encode("ascii"),
             access_token=self.admin_user_tok,
@@ -84,13 +82,13 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         # Invite/join another user
 
         url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url.encode("ascii"), access_token=self.admin_user_tok, content={}
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         url = "/groups/%s/self/accept_invite" % (group_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url.encode("ascii"), access_token=self.other_user_token, content={}
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -101,7 +99,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
 
         # Now delete the group
         url = "/_synapse/admin/v1/delete_group/" + group_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             access_token=self.admin_user_tok,
@@ -123,7 +121,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         """
 
         url = "/groups/%s/profile" % (group_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
 
@@ -134,7 +132,7 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
     def _get_groups_user_is_in(self, access_token):
         """Returns the list of groups the user is in (given their access token)
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/joined_groups".encode("ascii"), access_token=access_token
         )
 
@@ -155,9 +153,6 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-        self.hs = hs
-
         # Allow for uploading and downloading to/from the media repo
         self.media_repo = hs.get_media_repository_resource()
         self.download_resource = self.media_repo.children[b"download"]
@@ -210,13 +205,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         }
         config["media_storage_providers"] = [provider_config]
 
-        hs = self.setup_test_homeserver(config=config, http_client=client)
+        hs = self.setup_test_homeserver(config=config, federation_http_client=client)
 
         return hs
 
     def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
         """Ensure a piece of media is quarantined when trying to access it."""
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.download_resource),
             "GET",
@@ -241,7 +236,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # Attempt quarantine media APIs as non-admin
         url = "/_synapse/admin/v1/media/quarantine/example.org/abcde12345"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url.encode("ascii"), access_token=non_admin_user_tok,
         )
 
@@ -254,7 +249,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # And the roomID/userID endpoint
         url = "/_synapse/admin/v1/room/!room%3Aexample.com/media/quarantine"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url.encode("ascii"), access_token=non_admin_user_tok,
         )
 
@@ -282,7 +277,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         server_name, media_id = server_name_and_media_id.split("/")
 
         # Attempt to access the media
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.download_resource),
             "GET",
@@ -299,7 +294,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             urllib.parse.quote(server_name),
             urllib.parse.quote(media_id),
         )
-        request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+        channel = self.make_request("POST", url, access_token=admin_user_tok,)
         self.pump(1.0)
         self.assertEqual(200, int(channel.code), msg=channel.result["body"])
 
@@ -351,7 +346,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             url = "/_synapse/admin/v1/room/%s/media/quarantine" % urllib.parse.quote(
                 room_id
             )
-        request, channel = self.make_request("POST", url, access_token=admin_user_tok,)
+        channel = self.make_request("POST", url, access_token=admin_user_tok,)
         self.pump(1.0)
         self.assertEqual(200, int(channel.code), msg=channel.result["body"])
         self.assertEqual(
@@ -395,7 +390,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
             non_admin_user
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url.encode("ascii"), access_token=admin_user_tok,
         )
         self.pump(1.0)
@@ -431,13 +426,17 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         # Mark the second item as safe from quarantine.
         _, media_id_2 = server_and_media_id_2.split("/")
-        self.get_success(self.store.mark_local_media_as_safe(media_id_2))
+        # Quarantine the media
+        url = "/_synapse/admin/v1/media/protect/%s" % (urllib.parse.quote(media_id_2),)
+        channel = self.make_request("POST", url, access_token=admin_user_tok)
+        self.pump(1.0)
+        self.assertEqual(200, int(channel.code), msg=channel.result["body"])
 
         # Quarantine all media by this user
         url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
             non_admin_user
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url.encode("ascii"), access_token=admin_user_tok,
         )
         self.pump(1.0)
@@ -453,7 +452,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
 
         # Attempt to access each piece of media
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.download_resource),
             "GET",
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index cf3a007598..248c4442c3 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -50,17 +50,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         """
         Try to get a device of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request("PUT", self.url, b"{}")
+        channel = self.make_request("PUT", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request("DELETE", self.url, b"{}")
+        channel = self.make_request("DELETE", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -69,21 +69,21 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         """
         If the user is not a server admin, an error is returned.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url, access_token=self.other_user_token,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", self.url, access_token=self.other_user_token,
         )
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", self.url, access_token=self.other_user_token,
         )
 
@@ -99,23 +99,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             % self.other_user_device_id
         )
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
-            "PUT", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -129,23 +123,17 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             % self.other_user_device_id
         )
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
 
-        request, channel = self.make_request(
-            "PUT", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -158,22 +146,16 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             self.other_user
         )
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
-            "PUT", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("PUT", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         # Delete unknown device returns status 200
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -197,7 +179,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         }
 
         body = json.dumps(update)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url,
             access_token=self.admin_user_tok,
@@ -208,9 +190,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"])
 
         # Ensure the display name was not updated.
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("new display", channel.json_body["display_name"])
@@ -227,16 +207,12 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        request, channel = self.make_request(
-            "PUT", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("PUT", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Ensure the display name was not updated.
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("new display", channel.json_body["display_name"])
@@ -247,7 +223,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         """
         # Set new display_name
         body = json.dumps({"display_name": "new displayname"})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url,
             access_token=self.admin_user_tok,
@@ -257,9 +233,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Check new display_name
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("new displayname", channel.json_body["display_name"])
@@ -268,9 +242,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         """
         Tests that a normal lookup for a device is successfully
         """
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
@@ -291,7 +263,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(1, number_devices)
 
         # Delete device
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", self.url, access_token=self.admin_user_tok,
         )
 
@@ -323,7 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
         Try to list devices of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -334,9 +306,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -346,9 +316,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v2/users/@unknown_person:test/devices"
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -359,9 +327,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/devices"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -373,9 +339,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Get devices
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -391,9 +355,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
             self.login("user", "pass")
 
         # Get devices
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_devices, channel.json_body["total"])
@@ -431,7 +393,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         Try to delete devices of an user without authentication.
         """
-        request, channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request("POST", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -442,9 +404,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "POST", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("POST", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -454,9 +414,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v2/users/@unknown_person:test/delete_devices"
-        request, channel = self.make_request(
-            "POST", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -467,9 +425,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v2/users/@unknown_person:unknown_domain/delete_devices"
 
-        request, channel = self.make_request(
-            "POST", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -479,7 +435,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         Tests that a remove of a device that does not exist returns 200.
         """
         body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             access_token=self.admin_user_tok,
@@ -510,7 +466,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
 
         # Delete devices
         body = json.dumps({"devices": device_ids})
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             access_token=self.admin_user_tok,
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 11b72c10f7..d0090faa4f 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -32,8 +32,6 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -74,7 +72,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
         Try to get an event report without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -84,9 +82,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error 403 is returned.
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.other_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -96,9 +92,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["total"], 20)
@@ -111,7 +105,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with limit
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
         )
 
@@ -126,7 +120,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a defined starting point (from)
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5", access_token=self.admin_user_tok,
         )
 
@@ -141,7 +135,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a defined starting point and limit
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
         )
 
@@ -156,7 +150,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a filter of room
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?room_id=%s" % self.room_id1,
             access_token=self.admin_user_tok,
@@ -176,7 +170,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a filter of user
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?user_id=%s" % self.other_user,
             access_token=self.admin_user_tok,
@@ -196,7 +190,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing list of reported events with a filter of user and room
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1),
             access_token=self.admin_user_tok,
@@ -218,7 +212,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
 
         # fetch the most recent first, largest timestamp
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?dir=b", access_token=self.admin_user_tok,
         )
 
@@ -234,7 +228,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
             report += 1
 
         # fetch the oldest first, smallest timestamp
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?dir=f", access_token=self.admin_user_tok,
         )
 
@@ -254,7 +248,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing that a invalid search order returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
         )
 
@@ -267,7 +261,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing that a negative limit parameter returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
         )
 
@@ -279,7 +273,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         Testing that a negative from parameter returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
         )
 
@@ -293,7 +287,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of results is the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
         )
 
@@ -304,7 +298,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of max results is larger than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
         )
 
@@ -315,7 +309,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does appear
         # Number of max results is smaller than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
         )
 
@@ -327,7 +321,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         # Check
         # Set `from` to value of `next_token` for request remaining entries
         #  `next_token` does not appear
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=19", access_token=self.admin_user_tok,
         )
 
@@ -342,7 +336,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id, tok=user_tok)
         event_id = resp["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "rooms/%s/report/%s" % (room_id, event_id),
             json.dumps({"score": -100, "reason": "this makes me sad"}),
@@ -375,8 +369,6 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -399,7 +391,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         """
         Try to get event report without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -409,9 +401,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error 403 is returned.
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.other_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.other_user_tok,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -421,9 +411,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         Testing get a reported event
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self._check_fields(channel.json_body)
@@ -434,7 +422,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         """
 
         # `report_id` is negative
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v1/event_reports/-123",
             access_token=self.admin_user_tok,
@@ -448,7 +436,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         )
 
         # `report_id` is a non-numerical string
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v1/event_reports/abcdef",
             access_token=self.admin_user_tok,
@@ -462,7 +450,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         )
 
         # `report_id` is undefined
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v1/event_reports/",
             access_token=self.admin_user_tok,
@@ -480,7 +468,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         Testing that a not existing `report_id` returns a 404.
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v1/event_reports/123",
             access_token=self.admin_user_tok,
@@ -496,7 +484,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id, tok=user_tok)
         event_id = resp["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "rooms/%s/report/%s" % (room_id, event_id),
             json.dumps({"score": -100, "reason": "this makes me sad"}),
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index dadf9db660..51a7731693 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -35,7 +35,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.handler = hs.get_device_handler()
         self.media_repo = hs.get_media_repository_resource()
         self.server_name = hs.hostname
 
@@ -50,7 +49,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
 
-        request, channel = self.make_request("DELETE", url, b"{}")
+        channel = self.make_request("DELETE", url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -64,9 +63,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
 
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.other_user_token,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -77,9 +74,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345")
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -90,9 +85,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345")
 
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only delete local media", channel.json_body["error"])
@@ -121,7 +114,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertEqual(server_name, self.server_name)
 
         # Attempt to access media
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(download_resource),
             "GET",
@@ -146,9 +139,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id)
 
         # Delete media
-        request, channel = self.make_request(
-            "DELETE", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("DELETE", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
@@ -157,7 +148,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         )
 
         # Attempt to access media
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(download_resource),
             "GET",
@@ -189,7 +180,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.handler = hs.get_device_handler()
         self.media_repo = hs.get_media_repository_resource()
         self.server_name = hs.hostname
 
@@ -204,7 +194,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         Try to delete media without authentication.
         """
 
-        request, channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request("POST", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -216,7 +206,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.other_user = self.register_user("user", "pass")
         self.other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url, access_token=self.other_user_token,
         )
 
@@ -229,7 +219,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url + "?before_ts=1234", access_token=self.admin_user_tok,
         )
 
@@ -240,9 +230,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         """
         If the parameter `before_ts` is missing, an error is returned.
         """
-        request, channel = self.make_request(
-            "POST", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("POST", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
@@ -254,7 +242,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         """
         If parameters are invalid, an error is returned.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok,
         )
 
@@ -265,7 +253,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=1234&size_gt=-1234",
             access_token=self.admin_user_tok,
@@ -278,7 +266,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=1234&keep_profiles=not_bool",
             access_token=self.admin_user_tok,
@@ -308,7 +296,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         # timestamp after upload/create
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms),
             access_token=self.admin_user_tok,
@@ -332,7 +320,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         self._access_media(server_and_media_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms),
             access_token=self.admin_user_tok,
@@ -344,7 +332,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         # timestamp after upload
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms),
             access_token=self.admin_user_tok,
@@ -367,7 +355,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&size_gt=67",
             access_token=self.admin_user_tok,
@@ -378,7 +366,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&size_gt=66",
             access_token=self.admin_user_tok,
@@ -401,7 +389,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         # set media as avatar
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/profile/%s/avatar_url" % (self.admin_user,),
             content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}),
@@ -410,7 +398,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
             access_token=self.admin_user_tok,
@@ -421,7 +409,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
             access_token=self.admin_user_tok,
@@ -445,7 +433,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         # set media as room avatar
         room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/state/m.room.avatar" % (room_id,),
             content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}),
@@ -454,7 +442,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, channel.code, msg=channel.json_body)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true",
             access_token=self.admin_user_tok,
@@ -465,7 +453,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self._access_media(server_and_media_id)
 
         now_ms = self.clock.time_msec()
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false",
             access_token=self.admin_user_tok,
@@ -512,7 +500,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         media_id = server_and_media_id.split("/")[1]
         local_path = self.filepaths.local_media_filepath(media_id)
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(download_resource),
             "GET",
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46933a0493..7c47aa7e0a 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -20,6 +20,7 @@ from typing import List, Optional
 from mock import Mock
 
 import synapse.rest.admin
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import Codes
 from synapse.rest.client.v1 import directory, events, login, room
 
@@ -79,7 +80,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
 
         # Test that the admin can still send shutdown
         url = "/_synapse/admin/v1/shutdown_room/" + room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
@@ -103,7 +104,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
 
         # Enable world readable
         url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url.encode("ascii"),
             json.dumps({"history_visibility": "world_readable"}),
@@ -113,7 +114,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
 
         # Test that the admin can still send shutdown
         url = "/_synapse/admin/v1/shutdown_room/" + room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
@@ -130,7 +131,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         """
 
         url = "rooms/%s/initialSync" % (room_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.assertEqual(
@@ -138,7 +139,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
         )
 
         url = "events?timeout=0&room_id=" + room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.assertEqual(
@@ -184,7 +185,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         If the user is not a server admin, an error 403 is returned.
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
         )
 
@@ -197,7 +198,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url, json.dumps({}), access_token=self.admin_user_tok,
         )
 
@@ -210,7 +211,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/rooms/invalidroom/delete"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url, json.dumps({}), access_token=self.admin_user_tok,
         )
 
@@ -225,7 +226,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"new_room_user_id": "@unknown:test"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -244,7 +245,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"new_room_user_id": "@not:exist.bla"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -262,7 +263,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"block": "NotBool"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -278,7 +279,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"purge": "NotBool"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -304,7 +305,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         body = json.dumps({"block": True, "purge": True})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url.encode("ascii"),
             content=body.encode(encoding="utf_8"),
@@ -337,7 +338,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         body = json.dumps({"block": False, "purge": True})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url.encode("ascii"),
             content=body.encode(encoding="utf_8"),
@@ -371,7 +372,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         body = json.dumps({"block": False, "purge": False})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url.encode("ascii"),
             content=body.encode(encoding="utf_8"),
@@ -418,7 +419,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         # Test that the admin can still send shutdown
         url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
@@ -448,7 +449,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         # Enable world readable
         url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url.encode("ascii"),
             json.dumps({"history_visibility": "world_readable"}),
@@ -465,7 +466,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         # Test that the admin can still send shutdown
         url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             json.dumps({"new_room_user_id": self.admin_user}),
@@ -530,7 +531,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         """
 
         url = "rooms/%s/initialSync" % (room_id,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.assertEqual(
@@ -538,7 +539,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         )
 
         url = "events?timeout=0&room_id=" + room_id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok
         )
         self.assertEqual(
@@ -569,7 +570,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
         self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
 
         url = "/_synapse/admin/v1/purge_room"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url.encode("ascii"),
             {"room_id": room_id},
@@ -604,8 +605,6 @@ class RoomTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         # Create user
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -623,7 +622,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Request the list of rooms
         url = "/_synapse/admin/v1/rooms"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
 
@@ -704,7 +703,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
                 limit,
                 "name",
             )
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
             )
             self.assertEqual(
@@ -744,7 +743,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(room_ids, returned_room_ids)
 
         url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -764,7 +763,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Create a new alias to this room
         url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url.encode("ascii"),
             {"room_id": room_id},
@@ -794,7 +793,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         # Request the list of rooms
         url = "/_synapse/admin/v1/rooms"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -835,7 +834,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             url = "/_matrix/client/r0/directory/room/%s" % (
                 urllib.parse.quote(test_alias),
             )
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "PUT",
                 url.encode("ascii"),
                 {"room_id": room_id},
@@ -875,7 +874,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
             url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
             if reverse:
                 url += "&dir=b"
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
             )
             self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1011,7 +1010,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
                 expected_http_code: The expected http code for the request
             """
             url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET", url.encode("ascii"), access_token=self.admin_user_tok,
             )
             self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
@@ -1050,6 +1049,13 @@ class RoomTestCase(unittest.HomeserverTestCase):
         _search_test(room_id_2, "else")
         _search_test(room_id_2, "se")
 
+        # Test case insensitive
+        _search_test(room_id_1, "SOMETHING")
+        _search_test(room_id_1, "THING")
+
+        _search_test(room_id_2, "ELSE")
+        _search_test(room_id_2, "SE")
+
         _search_test(None, "foo")
         _search_test(None, "bar")
         _search_test(None, "", expected_http_code=400)
@@ -1072,7 +1078,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         )
 
         url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1084,6 +1090,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.assertIn("canonical_alias", channel.json_body)
         self.assertIn("joined_members", channel.json_body)
         self.assertIn("joined_local_members", channel.json_body)
+        self.assertIn("joined_local_devices", channel.json_body)
         self.assertIn("version", channel.json_body)
         self.assertIn("creator", channel.json_body)
         self.assertIn("encryption", channel.json_body)
@@ -1096,6 +1103,39 @@ class RoomTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(room_id_1, channel.json_body["room_id"])
 
+    def test_single_room_devices(self):
+        """Test that `joined_local_devices` can be requested correctly"""
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(1, channel.json_body["joined_local_devices"])
+
+        # Have another user join the room
+        user_1 = self.register_user("foo", "pass")
+        user_tok_1 = self.login("foo", "pass")
+        self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(2, channel.json_body["joined_local_devices"])
+
+        # leave room
+        self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok)
+        self.helper.leave(room_id_1, user_1, tok=user_tok_1)
+        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["joined_local_devices"])
+
     def test_room_members(self):
         """Test that room members can be requested correctly"""
         # Create two test rooms
@@ -1119,7 +1159,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.helper.join(room_id_2, user_3, tok=user_tok_3)
 
         url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1130,7 +1170,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["total"], 3)
 
         url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -1140,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.json_body["total"], 3)
 
+    def test_room_state(self):
+        """Test that room state can be requested correctly"""
+        # Create two test rooms
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
+        channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertIn("state", channel.json_body)
+        # testing that the state events match is painful and not done here. We assume that
+        # the create_room already does the right thing, so no need to verify that we got
+        # the state events it created.
+
 
 class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
@@ -1170,7 +1225,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1186,7 +1241,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"unknown_parameter": "@unknown:test"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1202,7 +1257,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"user_id": "@unknown:test"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1218,7 +1273,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"user_id": "@not:exist.bla"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1238,7 +1293,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         body = json.dumps({"user_id": self.second_user_id})
         url = "/_synapse/admin/v1/join/!unknown:test"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1255,7 +1310,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         body = json.dumps({"user_id": self.second_user_id})
         url = "/_synapse/admin/v1/join/invalidroom"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1274,7 +1329,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         """
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             self.url,
             content=body.encode(encoding="utf_8"),
@@ -1286,7 +1341,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
         # Validate if user is a member of the room
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1303,7 +1358,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/join/{}".format(private_room_id)
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1333,7 +1388,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
         # Validate if server admin is a member of the room
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1344,7 +1399,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/join/{}".format(private_room_id)
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1355,7 +1410,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
         # Validate if user is a member of the room
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1372,7 +1427,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/join/{}".format(private_room_id)
         body = json.dumps({"user_id": self.second_user_id})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             url,
             content=body.encode(encoding="utf_8"),
@@ -1384,13 +1439,150 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
 
         # Validate if user is a member of the room
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
         )
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
 
+class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.creator = self.register_user("creator", "test")
+        self.creator_tok = self.login("creator", "test")
+
+        self.second_user_id = self.register_user("second", "test")
+        self.second_tok = self.login("second", "test")
+
+        self.public_room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+        self.url = "/_synapse/admin/v1/rooms/{}/make_room_admin".format(
+            self.public_room_id
+        )
+
+    def test_public_room(self):
+        """Test that getting admin in a public room works.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Now we test that we can join the room and ban a user.
+        self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+        self.helper.change_membership(
+            room_id,
+            self.admin_user,
+            "@test:test",
+            Membership.BAN,
+            tok=self.admin_user_tok,
+        )
+
+    def test_private_room(self):
+        """Test that getting admin in a private room works and we get invited.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=False,
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Now we test that we can join the room (we should have received an
+        # invite) and can ban a user.
+        self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
+        self.helper.change_membership(
+            room_id,
+            self.admin_user,
+            "@test:test",
+            Membership.BAN,
+            tok=self.admin_user_tok,
+        )
+
+    def test_other_user(self):
+        """Test that giving admin in a public room works to a non-admin user works.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={"user_id": self.second_user_id},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Now we test that we can join the room and ban a user.
+        self.helper.join(room_id, self.second_user_id, tok=self.second_tok)
+        self.helper.change_membership(
+            room_id,
+            self.second_user_id,
+            "@test:test",
+            Membership.BAN,
+            tok=self.second_tok,
+        )
+
+    def test_not_enough_power(self):
+        """Test that we get a sensible error if there are no local room admins.
+        """
+        room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+
+        # The creator drops admin rights in the room.
+        pl = self.helper.get_state(
+            room_id, EventTypes.PowerLevels, tok=self.creator_tok
+        )
+        pl["users"][self.creator] = 0
+        self.helper.send_state(
+            room_id, EventTypes.PowerLevels, body=pl, tok=self.creator_tok
+        )
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/rooms/{}/make_room_admin".format(room_id),
+            content={},
+            access_token=self.admin_user_tok,
+        )
+
+        # We expect this to fail with a 400 as there are no room admins.
+        #
+        # (Note we assert the error message to ensure that it's not denied for
+        # some other reason)
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            channel.json_body["error"],
+            "No local admin user in room with power to update power levels.",
+        )
+
+
 PURGE_TABLES = [
     "current_state_events",
     "event_backward_extremities",
@@ -1419,7 +1611,6 @@ PURGE_TABLES = [
     "event_push_summary",
     "pusher_throttle",
     "group_summary_rooms",
-    "local_invites",
     "room_account_data",
     "room_tags",
     # "state_groups",  # Current impl leaves orphaned state groups around.
diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py
index 907b49f889..f48be3d65a 100644
--- a/tests/rest/admin/test_statistics.py
+++ b/tests/rest/admin/test_statistics.py
@@ -31,7 +31,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -46,7 +45,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         Try to list users without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -55,7 +54,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         If the user is not a server admin, an error 403 is returned.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url, json.dumps({}), access_token=self.other_user_tok,
         )
 
@@ -67,7 +66,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         If parameters are invalid, an error is returned.
         """
         # unkown order_by
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?order_by=bar", access_token=self.admin_user_tok,
         )
 
@@ -75,7 +74,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative from
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
         )
 
@@ -83,7 +82,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative limit
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
         )
 
@@ -91,7 +90,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative from_ts
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from_ts=-1234", access_token=self.admin_user_tok,
         )
 
@@ -99,7 +98,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative until_ts
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?until_ts=-1234", access_token=self.admin_user_tok,
         )
 
@@ -107,7 +106,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # until_ts smaller from_ts
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?from_ts=10&until_ts=5",
             access_token=self.admin_user_tok,
@@ -117,7 +116,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # empty search term
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?search_term=", access_token=self.admin_user_tok,
         )
 
@@ -125,7 +124,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid search order
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?dir=bar", access_token=self.admin_user_tok,
         )
 
@@ -138,7 +137,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         self._create_users_with_media(10, 2)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
         )
 
@@ -154,7 +153,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         self._create_users_with_media(20, 2)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5", access_token=self.admin_user_tok,
         )
 
@@ -170,7 +169,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         """
         self._create_users_with_media(20, 2)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
         )
 
@@ -190,7 +189,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of results is the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
         )
 
@@ -201,7 +200,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of max results is larger than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
         )
 
@@ -212,7 +211,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does appear
         # Number of max results is smaller than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
         )
 
@@ -223,7 +222,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
 
         # Set `from` to value of `next_token` for request remaining entries
         # Check `next_token` does not appear
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=19", access_token=self.admin_user_tok,
         )
 
@@ -238,9 +237,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         if users have no media created
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -316,15 +313,13 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         ts1 = self.clock.time_msec()
 
         # list all media when filter is not set
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
 
         # filter media starting at `ts1` after creating first media
         # result is 0
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from_ts=%s" % (ts1,), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -337,7 +332,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self._create_media(self.other_user_tok, 3)
 
         # filter media between `ts1` and `ts2`
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?from_ts=%s&until_ts=%s" % (ts1, ts2),
             access_token=self.admin_user_tok,
@@ -346,7 +341,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["users"][0]["media_count"], 3)
 
         # filter media until `ts2` and earlier
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?until_ts=%s" % (ts2,), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -356,14 +351,12 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self._create_users_with_media(20, 1)
 
         # check without filter get all users
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(channel.json_body["total"], 20)
 
         # filter user 1 and 10-19 by `user_id`
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?search_term=foo_user_1",
             access_token=self.admin_user_tok,
@@ -372,7 +365,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["total"], 11)
 
         # filter on this user in `displayname`
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             self.url + "?search_term=bar_user_10",
             access_token=self.admin_user_tok,
@@ -382,7 +375,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["total"], 1)
 
         # filter and get empty result
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?search_term=foobar", access_token=self.admin_user_tok,
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -447,7 +440,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
         url = self.url + "?order_by=%s" % (order_type,)
         if dir is not None and dir in ("b", "f"):
             url += "&dir=%s" % (dir,)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", url.encode("ascii"), access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 54d46f4bd3..ee05ee60bc 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,14 +18,17 @@ import hmac
 import json
 import urllib.parse
 from binascii import unhexlify
+from typing import Optional
 
 from mock import Mock
 
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
+from synapse.api.room_versions import RoomVersions
 from synapse.rest.client.v1 import login, logout, profile, room
 from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -70,7 +73,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
         self.hs.config.registration_shared_secret = None
 
-        request, channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request("POST", self.url, b"{}")
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
@@ -87,7 +90,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         self.hs.get_secrets = Mock(return_value=secrets)
 
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
 
         self.assertEqual(channel.json_body, {"nonce": "abcd"})
 
@@ -96,14 +99,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         Calling GET on the endpoint will return a randomised nonce, which will
         only last for SALT_TIMEOUT (60s).
         """
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         # 59 seconds
         self.reactor.advance(59)
 
         body = json.dumps({"nonce": nonce})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("username must be specified", channel.json_body["error"])
@@ -111,7 +114,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         # 61 seconds
         self.reactor.advance(2)
 
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -120,7 +123,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
         Only the provided nonce can be used, as it's checked in the MAC.
         """
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -136,7 +139,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("HMAC incorrect", channel.json_body["error"])
@@ -146,7 +149,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         When the correct nonce is provided, and the right key is provided, the
         user is registered.
         """
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -165,7 +168,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -174,7 +177,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
         A valid unrecognised nonce.
         """
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -190,13 +193,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
 
         # Now, try and reuse it
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -209,7 +212,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
 
         def nonce():
-            request, channel = self.make_request("GET", self.url)
+            channel = self.make_request("GET", self.url)
             return channel.json_body["nonce"]
 
         #
@@ -218,7 +221,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Must be present
         body = json.dumps({})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("nonce must be specified", channel.json_body["error"])
@@ -229,28 +232,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Must be present
         body = json.dumps({"nonce": nonce()})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("username must be specified", channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": 1234})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid username", channel.json_body["error"])
@@ -261,28 +264,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         # Must be present
         body = json.dumps({"nonce": nonce(), "username": "a"})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("password must be specified", channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid password", channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid password", channel.json_body["error"])
 
         # Super long
         body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid password", channel.json_body["error"])
@@ -300,7 +303,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "user_type": "invalid",
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Invalid user type", channel.json_body["error"])
@@ -311,7 +314,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         """
 
         # set no displayname
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -321,17 +324,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         body = json.dumps(
             {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac}
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob1:test", channel.json_body["user_id"])
 
-        request, channel = self.make_request("GET", "/profile/@bob1:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob1:test/displayname")
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("bob1", channel.json_body["displayname"])
 
         # displayname is None
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -347,17 +350,17 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob2:test", channel.json_body["user_id"])
 
-        request, channel = self.make_request("GET", "/profile/@bob2:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob2:test/displayname")
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("bob2", channel.json_body["displayname"])
 
         # displayname is empty
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -373,16 +376,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob3:test", channel.json_body["user_id"])
 
-        request, channel = self.make_request("GET", "/profile/@bob3:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob3:test/displayname")
         self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
 
         # set displayname
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -398,12 +401,12 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob4:test", channel.json_body["user_id"])
 
-        request, channel = self.make_request("GET", "/profile/@bob4:test/displayname")
+        channel = self.make_request("GET", "/profile/@bob4:test/displayname")
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("Bob's Name", channel.json_body["displayname"])
 
@@ -429,7 +432,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         )
 
         # Register new user with admin API
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -448,7 +451,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
                 "mac": want_mac,
             }
         )
-        request, channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body.encode("utf8"))
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -466,23 +469,34 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        self.register_user("user1", "pass1", admin=False)
-        self.register_user("user2", "pass2", admin=False)
-
     def test_no_auth(self):
         """
         Try to list users without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        self._create_users(1)
+        other_user_token = self.login("user1", "pass1")
+
+        channel = self.make_request("GET", self.url, access_token=other_user_token)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_all_users(self):
         """
         List all users, including deactivated users.
         """
-        request, channel = self.make_request(
+        self._create_users(2)
+
+        channel = self.make_request(
             "GET",
             self.url + "?deactivated=true",
             b"{}",
@@ -493,6 +507,449 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(3, len(channel.json_body["users"]))
         self.assertEqual(3, channel.json_body["total"])
 
+        # Check that all fields are available
+        self._check_fields(channel.json_body["users"])
+
+    def test_search_term(self):
+        """Test that searching for a users works correctly"""
+
+        def _search_test(
+            expected_user_id: Optional[str],
+            search_term: str,
+            search_field: Optional[str] = "name",
+            expected_http_code: Optional[int] = 200,
+        ):
+            """Search for a user and check that the returned user's id is a match
+
+            Args:
+                expected_user_id: The user_id expected to be returned by the API. Set
+                    to None to expect zero results for the search
+                search_term: The term to search for user names with
+                search_field: Field which is to request: `name` or `user_id`
+                expected_http_code: The expected http code for the request
+            """
+            url = self.url + "?%s=%s" % (search_field, search_term,)
+            channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+            if expected_http_code != 200:
+                return
+
+            # Check that users were returned
+            self.assertTrue("users" in channel.json_body)
+            self._check_fields(channel.json_body["users"])
+            users = channel.json_body["users"]
+
+            # Check that the expected number of users were returned
+            expected_user_count = 1 if expected_user_id else 0
+            self.assertEqual(len(users), expected_user_count)
+            self.assertEqual(channel.json_body["total"], expected_user_count)
+
+            if expected_user_id:
+                # Check that the first returned user id is correct
+                u = users[0]
+                self.assertEqual(expected_user_id, u["name"])
+
+        self._create_users(2)
+
+        user1 = "@user1:test"
+        user2 = "@user2:test"
+
+        # Perform search tests
+        _search_test(user1, "er1")
+        _search_test(user1, "me 1")
+
+        _search_test(user2, "er2")
+        _search_test(user2, "me 2")
+
+        _search_test(user1, "er1", "user_id")
+        _search_test(user2, "er2", "user_id")
+
+        # Test case insensitive
+        _search_test(user1, "ER1")
+        _search_test(user1, "NAME 1")
+
+        _search_test(user2, "ER2")
+        _search_test(user2, "NAME 2")
+
+        _search_test(user1, "ER1", "user_id")
+        _search_test(user2, "ER2", "user_id")
+
+        _search_test(None, "foo")
+        _search_test(None, "bar")
+
+        _search_test(None, "foo", "user_id")
+        _search_test(None, "bar", "user_id")
+
+    def test_invalid_parameter(self):
+        """
+        If parameters are invalid, an error is returned.
+        """
+
+        # negative limit
+        channel = self.make_request(
+            "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # negative from
+        channel = self.make_request(
+            "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # invalid guests
+        channel = self.make_request(
+            "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # invalid deactivated
+        channel = self.make_request(
+            "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+    def test_limit(self):
+        """
+        Testing list of users with limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 5)
+        self.assertEqual(channel.json_body["next_token"], "5")
+        self._check_fields(channel.json_body["users"])
+
+    def test_from(self):
+        """
+        Testing list of users with a defined starting point (from)
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 15)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["users"])
+
+    def test_limit_and_from(self):
+        """
+        Testing list of users with a defined starting point and limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(channel.json_body["next_token"], "15")
+        self.assertEqual(len(channel.json_body["users"]), 10)
+        self._check_fields(channel.json_body["users"])
+
+    def test_next_token(self):
+        """
+        Testing that `next_token` appears at the right place
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        #  `next_token` does not appear
+        # Number of results is the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does not appear
+        # Number of max results is larger than the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        #  `next_token` does appear
+        # Number of max results is smaller than the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 19)
+        self.assertEqual(channel.json_body["next_token"], "19")
+
+        # Check
+        # Set `from` to value of `next_token` for request remaining entries
+        #  `next_token` does not appear
+        channel = self.make_request(
+            "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 1)
+        self.assertNotIn("next_token", channel.json_body)
+
+    def _check_fields(self, content: JsonDict):
+        """Checks that the expected user attributes are present in content
+        Args:
+            content: List that is checked for content
+        """
+        for u in content:
+            self.assertIn("name", u)
+            self.assertIn("is_guest", u)
+            self.assertIn("admin", u)
+            self.assertIn("user_type", u)
+            self.assertIn("deactivated", u)
+            self.assertIn("displayname", u)
+            self.assertIn("avatar_url", u)
+
+    def _create_users(self, number_users: int):
+        """
+        Create a number of users
+        Args:
+            number_users: Number of users to be created
+        """
+        for i in range(1, number_users + 1):
+            self.register_user(
+                "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
+            )
+
+
+class DeactivateAccountTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass", displayname="User1")
+        self.other_user_token = self.login("user", "pass")
+        self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
+            self.other_user
+        )
+        self.url = "/_synapse/admin/v1/deactivate/%s" % urllib.parse.quote(
+            self.other_user
+        )
+
+        # set attributes for user
+        self.get_success(
+            self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+        )
+        self.get_success(
+            self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+        )
+
+    def test_no_auth(self):
+        """
+        Try to deactivate users without authentication.
+        """
+        channel = self.make_request("POST", self.url, b"{}")
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        url = "/_synapse/admin/v1/deactivate/@bob:test"
+
+        channel = self.make_request("POST", url, access_token=self.other_user_token)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+        channel = self.make_request(
+            "POST", url, access_token=self.other_user_token, content=b"{}",
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("You are not a server admin", channel.json_body["error"])
+
+    def test_user_does_not_exist(self):
+        """
+        Tests that deactivation for a user that does not exist returns a 404
+        """
+
+        channel = self.make_request(
+            "POST",
+            "/_synapse/admin/v1/deactivate/@unknown_person:test",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_erase_is_not_bool(self):
+        """
+        If parameter `erase` is not boolean, return an error
+        """
+        body = json.dumps({"erase": "False"})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that deactivation for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/deactivate/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual("Can only deactivate local users", channel.json_body["error"])
+
+    def test_deactivate_user_erase_true(self):
+        """
+        Test deactivating an user and set `erase` to `true`
+        """
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        # Deactivate user
+        body = json.dumps({"erase": True})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertIsNone(channel.json_body["avatar_url"])
+        self.assertIsNone(channel.json_body["displayname"])
+
+        self._is_erased("@user:test", True)
+
+    def test_deactivate_user_erase_false(self):
+        """
+        Test deactivating an user and set `erase` to `false`
+        """
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        # Deactivate user
+        body = json.dumps({"erase": False})
+
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User1", channel.json_body["displayname"])
+
+        self._is_erased("@user:test", False)
+
+    def _is_erased(self, user_id: str, expect: bool) -> None:
+        """Assert that the user is erased or not
+        """
+        d = self.store.is_user_erased(user_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertFalse(self.get_success(d))
+
 
 class UserRestTestCase(unittest.HomeserverTestCase):
 
@@ -508,7 +965,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        self.other_user = self.register_user("user", "pass")
+        self.other_user = self.register_user("user", "pass", displayname="User")
         self.other_user_token = self.login("user", "pass")
         self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
             self.other_user
@@ -520,14 +977,12 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v2/users/@bob:test"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.other_user_token,
-        )
+        channel = self.make_request("GET", url, access_token=self.other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("You are not a server admin", channel.json_body["error"])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, access_token=self.other_user_token, content=b"{}",
         )
 
@@ -539,7 +994,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_synapse/admin/v2/users/@unknown_person:test",
             access_token=self.admin_user_tok,
@@ -561,11 +1016,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
                 "admin": True,
                 "displayname": "Bob's name",
                 "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
-                "avatar_url": None,
+                "avatar_url": "mxc://fibble/wibble",
             }
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -578,11 +1033,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
         self.assertEqual(True, channel.json_body["admin"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -592,6 +1046,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(True, channel.json_body["admin"])
         self.assertEqual(False, channel.json_body["is_guest"])
         self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     def test_create_user(self):
         """
@@ -606,10 +1061,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
                 "admin": False,
                 "displayname": "Bob's name",
                 "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+                "avatar_url": "mxc://fibble/wibble",
             }
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -622,11 +1078,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
         self.assertEqual(False, channel.json_body["admin"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -636,6 +1091,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(False, channel.json_body["admin"])
         self.assertEqual(False, channel.json_body["is_guest"])
         self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     @override_config(
         {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -651,9 +1107,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Sync to set admin user to active
         # before limit of monthly active users is reached
-        request, channel = self.make_request(
-            "GET", "/sync", access_token=self.admin_user_tok
-        )
+        channel = self.make_request("GET", "/sync", access_token=self.admin_user_tok)
 
         if channel.code != 200:
             raise HttpResponseException(
@@ -676,7 +1130,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Create user
         body = json.dumps({"password": "abc123", "admin": False})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -715,7 +1169,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Create user
         body = json.dumps({"password": "abc123", "admin": False})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -752,7 +1206,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             }
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -769,7 +1223,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertEqual("@bob:test", pushers[0]["user_name"])
+        self.assertEqual("@bob:test", pushers[0].user_name)
 
     @override_config(
         {
@@ -796,7 +1250,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             }
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -822,7 +1276,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Change password
         body = json.dumps({"password": "hahaha"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -839,7 +1293,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Modify user
         body = json.dumps({"displayname": "foobar"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -851,7 +1305,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("foobar", channel.json_body["displayname"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
@@ -869,7 +1323,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             {"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -882,7 +1336,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
@@ -896,10 +1350,30 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         Test deactivating another user.
         """
 
+        # set attributes for user
+        self.get_success(
+            self.store.set_profile_avatar_url("user", "mxc://servername/mediaid")
+        )
+        self.get_success(
+            self.store.user_add_threepid("@user:test", "email", "foo@bar.com", 0, 0)
+        )
+
+        # Get user
+        channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
+
         # Deactivate user
         body = json.dumps({"deactivated": True})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -909,16 +1383,70 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
         # the user is deactivated, the threepid will be deleted
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
+        self.assertEqual("User", channel.json_body["displayname"])
+
+    @override_config({"user_directory": {"enabled": True, "search_all_users": True}})
+    def test_change_name_deactivate_user_user_directory(self):
+        """
+        Test change profile information of a deactivated user and
+        check that it does not appear in user directory
+        """
+
+        # is in user directory
+        profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        self.assertTrue(profile["display_name"] == "User")
+
+        # Deactivate user
+        body = json.dumps({"deactivated": True})
+
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+
+        # is not in user directory
+        profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        self.assertTrue(profile is None)
+
+        # Set new displayname user
+        body = json.dumps({"displayname": "Foobar"})
+
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=body.encode(encoding="utf_8"),
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertEqual("Foobar", channel.json_body["displayname"])
+
+        # is not in user directory
+        profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        self.assertTrue(profile is None)
 
     def test_reactivate_user(self):
         """
@@ -926,7 +1454,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Deactivate the user.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -939,7 +1467,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self._is_erased("@user:test", True)
 
         # Attempt to reactivate the user (without a password).
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -948,7 +1476,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
 
         # Reactivate the user.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -959,7 +1487,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
@@ -976,7 +1504,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Set a user as an admin
         body = json.dumps({"admin": True})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
@@ -988,7 +1516,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(True, channel.json_body["admin"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_other_user, access_token=self.admin_user_tok,
         )
 
@@ -1006,7 +1534,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Create user
         body = json.dumps({"password": "abc123"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -1018,9 +1546,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("bob", channel.json_body["displayname"])
 
         # Get user
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1030,7 +1556,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Change password (and use a str for deactivate instead of a bool)
         body = json.dumps({"password": "abc123", "deactivated": "false"})  # oops!
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
@@ -1040,9 +1566,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
 
         # Check user is not deactivated
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
@@ -1070,8 +1594,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -1084,7 +1606,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         """
         Try to list rooms of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1095,37 +1617,33 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_user_does_not_exist(self):
         """
-        Tests that a lookup for a user that does not exist returns a 404
+        Tests that a lookup for a user that does not exist returns an empty list
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/joined_rooms"
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["total"])
+        self.assertEqual(0, len(channel.json_body["joined_rooms"]))
 
     def test_user_is_not_local(self):
         """
-        Tests that a lookup for a user that is not a local returns a 400
+        Tests that a lookup for a user that is not a local and participates in no conversation returns an empty list
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/joined_rooms"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["total"])
+        self.assertEqual(0, len(channel.json_body["joined_rooms"]))
 
     def test_no_memberships(self):
         """
@@ -1133,9 +1651,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         if user has no memberships
         """
         # Get rooms
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1152,14 +1668,55 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
             self.helper.create_room_as(self.other_user, tok=other_user_tok)
 
         # Get rooms
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_rooms, channel.json_body["total"])
         self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
 
+    def test_get_rooms_with_nonlocal_user(self):
+        """
+        Tests that a normal lookup for rooms is successful with a non-local user
+        """
+
+        other_user_tok = self.login("user", "pass")
+        event_builder_factory = self.hs.get_event_builder_factory()
+        event_creation_handler = self.hs.get_event_creation_handler()
+        storage = self.hs.get_storage()
+
+        # Create two rooms, one with a local user only and one with both a local
+        # and remote user.
+        self.helper.create_room_as(self.other_user, tok=other_user_tok)
+        local_and_remote_room_id = self.helper.create_room_as(
+            self.other_user, tok=other_user_tok
+        )
+
+        # Add a remote user to the room.
+        builder = event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": "m.room.member",
+                "sender": "@joiner:remote_hs",
+                "state_key": "@joiner:remote_hs",
+                "room_id": local_and_remote_room_id,
+                "content": {"membership": "join"},
+            },
+        )
+
+        event, context = self.get_success(
+            event_creation_handler.create_new_client_event(builder)
+        )
+
+        self.get_success(storage.persistence.persist_event(event, context))
+
+        # Now get rooms
+        url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(1, channel.json_body["total"])
+        self.assertEqual([local_and_remote_room_id], channel.json_body["joined_rooms"])
+
 
 class PushersRestTestCase(unittest.HomeserverTestCase):
 
@@ -1183,7 +1740,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
         Try to list pushers of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1194,9 +1751,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1206,9 +1761,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1219,9 +1772,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1232,9 +1783,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Get pushers
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1256,14 +1805,12 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "example.com"},
+                data={"url": "https://example.com/_matrix/push/v1/notify"},
             )
         )
 
         # Get pushers
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(1, channel.json_body["total"])
@@ -1287,7 +1834,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -1302,7 +1848,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
         Try to list media of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1313,9 +1859,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url, access_token=other_user_token,)
 
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@@ -1325,9 +1869,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/media"
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@@ -1338,9 +1880,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         """
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
 
-        request, channel = self.make_request(
-            "GET", url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url, access_token=self.admin_user_tok,)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only lookup local users", channel.json_body["error"])
@@ -1354,7 +1894,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
         )
 
@@ -1373,7 +1913,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5", access_token=self.admin_user_tok,
         )
 
@@ -1392,7 +1932,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
         )
 
@@ -1407,7 +1947,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         Testing that a negative limit parameter returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
         )
 
@@ -1419,7 +1959,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         Testing that a negative from parameter returns a 400
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
         )
 
@@ -1437,7 +1977,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of results is the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
         )
 
@@ -1448,7 +1988,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does not appear
         # Number of max results is larger than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
         )
 
@@ -1459,7 +1999,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
 
         #  `next_token` does appear
         # Number of max results is smaller than the number of entries
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
         )
 
@@ -1471,7 +2011,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         # Check
         # Set `from` to value of `next_token` for request remaining entries
         #  `next_token` does not appear
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url + "?from=19", access_token=self.admin_user_tok,
         )
 
@@ -1486,9 +2026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         if user has no media created
         """
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(0, channel.json_body["total"])
@@ -1503,9 +2041,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         self._create_media(other_user_tok, number_media)
 
-        request, channel = self.make_request(
-            "GET", self.url, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url, access_token=self.admin_user_tok,)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(number_media, channel.json_body["total"])
@@ -1571,7 +2107,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         )
 
     def _get_token(self) -> str:
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url, b"{}", access_token=self.admin_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1580,7 +2116,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
     def test_no_auth(self):
         """Try to login as a user without authentication.
         """
-        request, channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request("POST", self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -1588,7 +2124,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
     def test_not_admin(self):
         """Try to login as a user as a non-admin user.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", self.url, b"{}", access_token=self.other_user_tok
         )
 
@@ -1616,7 +2152,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         self._get_token()
 
         # Check that we don't see a new device in our devices list
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1631,25 +2167,19 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         puppet_token = self._get_token()
 
         # Test that we can successfully make a request
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Logout with the puppet token
-        request, channel = self.make_request(
-            "POST", "logout", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # The puppet token should no longer work
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
 
         # .. but the real user's tokens should still work
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1662,25 +2192,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         puppet_token = self._get_token()
 
         # Test that we can successfully make a request
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Logout all with the real user token
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "logout/all", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # The puppet token should still work
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # .. but the real user's tokens shouldn't
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
@@ -1693,25 +2219,21 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         puppet_token = self._get_token()
 
         # Test that we can successfully make a request
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Logout all with the admin user token
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "logout/all", b"{}", access_token=self.admin_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # The puppet token should no longer work
-        request, channel = self.make_request(
-            "GET", "devices", b"{}", access_token=puppet_token
-        )
+        channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
 
         # .. but the real user's tokens should still work
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@@ -1778,8 +2300,6 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -1793,11 +2313,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         Try to get information of an user without authentication.
         """
-        request, channel = self.make_request("GET", self.url1, b"{}")
+        channel = self.make_request("GET", self.url1, b"{}")
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request("GET", self.url2, b"{}")
+        channel = self.make_request("GET", self.url2, b"{}")
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
@@ -1808,15 +2328,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         self.register_user("user2", "pass")
         other_user2_token = self.login("user2", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url1, access_token=other_user2_token,
-        )
+        channel = self.make_request("GET", self.url1, access_token=other_user2_token,)
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        request, channel = self.make_request(
-            "GET", self.url2, access_token=other_user2_token,
-        )
+        channel = self.make_request("GET", self.url2, access_token=other_user2_token,)
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
@@ -1827,15 +2343,11 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
         url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
 
-        request, channel = self.make_request(
-            "GET", url1, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url1, access_token=self.admin_user_tok,)
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only whois a local user", channel.json_body["error"])
 
-        request, channel = self.make_request(
-            "GET", url2, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", url2, access_token=self.admin_user_tok,)
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Can only whois a local user", channel.json_body["error"])
 
@@ -1843,16 +2355,12 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         The lookup should succeed for an admin.
         """
-        request, channel = self.make_request(
-            "GET", self.url1, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url1, access_token=self.admin_user_tok,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
 
-        request, channel = self.make_request(
-            "GET", self.url2, access_token=self.admin_user_tok,
-        )
+        channel = self.make_request("GET", self.url2, access_token=self.admin_user_tok,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
@@ -1863,16 +2371,76 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         other_user_token = self.login("user", "pass")
 
-        request, channel = self.make_request(
-            "GET", self.url1, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url1, access_token=other_user_token,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
 
-        request, channel = self.make_request(
-            "GET", self.url2, access_token=other_user_token,
-        )
+        channel = self.make_request("GET", self.url2, access_token=other_user_token,)
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(self.other_user, channel.json_body["user_id"])
         self.assertIn("devices", channel.json_body)
+
+
+class ShadowBanRestTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+
+        self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
+            self.other_user
+        )
+
+    def test_no_auth(self):
+        """
+        Try to get information of an user without authentication.
+        """
+        channel = self.make_request("POST", self.url)
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        other_user_token = self.login("user", "pass")
+
+        channel = self.make_request("POST", self.url, access_token=other_user_token)
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that shadow-banning for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+
+    def test_success(self):
+        """
+        Shadow-banning should succeed for an admin.
+        """
+        # The user starts off as not shadow-banned.
+        other_user_token = self.login("user", "pass")
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertFalse(result.shadow_banned)
+
+        channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual({}, channel.json_body)
+
+        # Ensure the user is shadow-banned (and the cache was cleared).
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertTrue(result.shadow_banned)
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index e2e6a5e16d..c74693e9b2 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -61,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
     def test_render_public_consent(self):
         """You can observe the terms form without specifying a user"""
         resource = consent_resource.ConsentResource(self.hs)
-        request, channel = make_request(
+        channel = make_request(
             self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
         )
         self.assertEqual(channel.code, 200)
@@ -82,7 +82,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
             uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
             + "&u=user"
         )
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "GET",
@@ -97,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(consented, "False")
 
         # POST to the consent page, saying we've agreed
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "POST",
@@ -109,7 +109,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
 
         # Fetch the consent page, to get the consent version -- it should have
         # changed
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "GET",
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index a1ccc4ee9a..56937dcd2e 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -93,7 +93,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
     def get_event(self, room_id, event_id, expected_code=200):
         url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
 
-        request, channel = self.make_request("GET", url)
+        channel = self.make_request("GET", url)
 
         self.assertEqual(channel.code, expected_code, channel.result)
 
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index 259c6a1985..c0a9fc6925 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -43,9 +43,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
-        request, channel = self.make_request(
-            b"POST", "/createRoom", b"{}", access_token=tok
-        )
+        channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
         room_id = channel.json_body["room_id"]
 
@@ -56,7 +54,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         }
         request_data = json.dumps(params)
         request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", request_url, request_data, access_token=tok
         )
         self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index c1f516cc93..f0707646bb 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -69,16 +69,12 @@ class RedactionsTestCase(HomeserverTestCase):
         """
         path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
 
-        request, channel = self.make_request(
-            "POST", path, content={}, access_token=access_token
-        )
+        channel = self.make_request("POST", path, content={}, access_token=access_token)
         self.assertEqual(int(channel.result["code"]), expect_code)
         return channel.json_body
 
     def _sync_room_timeline(self, access_token, room_id):
-        request, channel = self.make_request(
-            "GET", "sync", access_token=self.mod_access_token
-        )
+        channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
         self.assertEqual(channel.result["code"], b"200")
         room_sync = channel.json_body["rooms"]["join"][room_id]
         return room_sync["timeline"]["events"]
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index f56b5d9231..31dc832fd5 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -325,7 +325,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
     def get_event(self, room_id, event_id, expected_code=200):
         url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
 
-        request, channel = self.make_request("GET", url, access_token=self.token)
+        channel = self.make_request("GET", url, access_token=self.token)
 
         self.assertEqual(channel.code, expected_code, channel.result)
 
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 94dcfb9f7c..0ebdf1415b 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -18,6 +18,7 @@ import synapse.rest.admin
 from synapse.api.constants import EventTypes
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.types import UserID
 
 from tests import unittest
 
@@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
         self.store = self.hs.get_datastore()
 
         self.get_success(
-            self.store.db_pool.simple_update(
-                table="users",
-                keyvalues={"name": self.banned_user_id},
-                updatevalues={"shadow_banned": True},
-                desc="shadow_ban",
-            )
+            self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
         )
 
         self.other_user_id = self.register_user("otheruser", "pass")
@@ -89,7 +85,7 @@ class RoomTestCase(_ShadowBannedBase):
         )
 
         # Inviting the user completes successfully.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/rooms/%s/invite" % (room_id,),
             {"id_server": "test", "medium": "email", "address": "test@test.test"},
@@ -103,7 +99,7 @@ class RoomTestCase(_ShadowBannedBase):
     def test_create_room(self):
         """Invitations during a room creation should be discarded, but the room still gets created."""
         # The room creation is successful.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/createRoom",
             {"visibility": "public", "invite": [self.other_user_id]},
@@ -158,7 +154,7 @@ class RoomTestCase(_ShadowBannedBase):
             self.banned_user_id, tok=self.banned_access_token
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
             {"new_version": "6"},
@@ -183,7 +179,7 @@ class RoomTestCase(_ShadowBannedBase):
             self.banned_user_id, tok=self.banned_access_token
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
             {"typing": True, "timeout": 30000},
@@ -198,7 +194,7 @@ class RoomTestCase(_ShadowBannedBase):
         # The other user can join and send typing events.
         self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (room_id, self.other_user_id),
             {"typing": True, "timeout": 30000},
@@ -244,7 +240,7 @@ class ProfileTestCase(_ShadowBannedBase):
         )
 
         # The update should succeed.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
             {"displayname": new_display_name},
@@ -254,7 +250,7 @@ class ProfileTestCase(_ShadowBannedBase):
         self.assertEqual(channel.json_body, {})
 
         # The user's display name should be updated.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/profile/%s/displayname" % (self.banned_user_id,)
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -282,7 +278,7 @@ class ProfileTestCase(_ShadowBannedBase):
         )
 
         # The update should succeed.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
             % (room_id, self.banned_user_id),
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 0e96697f9b..227fffab58 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -86,7 +86,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         callback = Mock(spec=[], side_effect=check)
         current_rules_module().check_event_allowed = callback
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
             {},
@@ -104,7 +104,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             self.assertEqual(ev.type, k[0])
             self.assertEqual(ev.state_key, k[1])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
             {},
@@ -123,7 +123,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         current_rules_module().check_event_allowed = check
 
         # now send the event
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
             {"x": "x"},
@@ -142,7 +142,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         current_rules_module().check_event_allowed = check
 
         # now send the event
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
             {"x": "x"},
@@ -152,7 +152,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         event_id = channel.json_body["event_id"]
 
         # ... and check that it got modified
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
             access_token=self.tok,
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 7a2c653df8..edd1d184f8 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -91,7 +91,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         # that we can make sure that the check is done on the whole alias.
         data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
         request_data = json.dumps(data)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, 400, channel.result)
@@ -104,7 +104,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         # as cautious as possible here.
         data = {"room_alias_name": random_string(5)}
         request_data = json.dumps(data)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -118,7 +118,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         data = {"aliases": [self.random_alias(alias_length)]}
         request_data = json.dumps(data)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
@@ -128,7 +128,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         data = {"room_id": self.room_id}
         request_data = json.dumps(data)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, request_data, access_token=self.user_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 12a93f5687..0a5ca317ea 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -63,13 +63,13 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         # implementation is now part of the r0 implementation, the newer
         # behaviour is used instead to be consistent with the r0 spec.
         # see issue #2602
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/events?access_token=%s" % ("invalid" + self.token,)
         )
         self.assertEquals(channel.code, 401, msg=channel.result)
 
         # valid token, expect content
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/events?access_token=%s&timeout=0" % (self.token,)
         )
         self.assertEquals(channel.code, 200, msg=channel.result)
@@ -87,7 +87,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         )
 
         # valid token, expect content
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/events?access_token=%s&timeout=0" % (self.token,)
         )
         self.assertEquals(channel.code, 200, msg=channel.result)
@@ -149,7 +149,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(self.room_id, tok=self.token)
         event_id = resp["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/events/" + event_id, access_token=self.token,
         )
         self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 176ddf7ec9..bfcb786af8 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,23 +1,83 @@
-import json
+# -*- coding: utf-8 -*-
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 import urllib.parse
+from typing import Any, Dict, Union
+from urllib.parse import urlencode
 
 from mock import Mock
 
-import jwt
+import pymacaroons
+
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.appservice import ApplicationService
 from synapse.rest.client.v1 import login, logout
 from synapse.rest.client.v2_alpha import devices, register
 from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.types import create_requester
 
 from tests import unittest
-from tests.unittest import override_config
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.handlers.test_saml import has_saml2
+from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.test_utils.html_parsers import TestHtmlParser
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+
+try:
+    import jwt
+
+    HAS_JWT = True
+except ImportError:
+    HAS_JWT = False
+
+
+# public_base_url used in some tests
+BASE_URL = "https://synapse/"
+
+# CAS server used in some tests
+CAS_SERVER = "https://fake.test"
+
+# just enough to tell pysaml2 where to redirect to
+SAML_SERVER = "https://test.saml.server/idp/sso"
+TEST_SAML_METADATA = """
+<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
+  <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
+      <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
+  </md:IDPSSODescriptor>
+</md:EntityDescriptor>
+""" % {
+    "SAML_SERVER": SAML_SERVER,
+}
 
 LOGIN_URL = b"/_matrix/client/r0/login"
 TEST_URL = b"/_matrix/client/r0/account/whoami"
 
+# a (valid) url with some annoying characters in.  %3D is =, %26 is &, %2B is +
+TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+# the query params in TEST_CLIENT_REDIRECT_URL
+EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
@@ -63,7 +123,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
                 "password": "monkey",
             }
-            request, channel = self.make_request(b"POST", LOGIN_URL, params)
+            channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -82,7 +142,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
             "password": "monkey",
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
@@ -108,7 +168,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "monkey",
             }
-            request, channel = self.make_request(b"POST", LOGIN_URL, params)
+            channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -127,7 +187,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "monkey",
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
@@ -153,7 +213,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "notamonkey",
             }
-            request, channel = self.make_request(b"POST", LOGIN_URL, params)
+            channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -172,7 +232,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "notamonkey",
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
@@ -181,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self.register_user("kermit", "monkey")
 
         # we shouldn't be able to make requests without an access token
-        request, channel = self.make_request(b"GET", TEST_URL)
+        channel = self.make_request(b"GET", TEST_URL)
         self.assertEquals(channel.result["code"], b"401", channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
 
@@ -191,25 +251,21 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "monkey",
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.code, 200, channel.result)
         access_token = channel.json_body["access_token"]
         device_id = channel.json_body["device_id"]
 
         # we should now be able to make requests with the access token
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
@@ -223,9 +279,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # more requests with the expired token should still return a soft-logout
         self.reactor.advance(3600)
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
@@ -233,16 +287,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # ... but if we delete that device, it will be a proper logout
         self._delete_device(access_token_2, "kermit", "monkey", device_id)
 
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], False)
 
     def _delete_device(self, access_token, user_id, password, device_id):
         """Perform the UI-Auth to delete a device"""
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"DELETE", "devices/" + device_id, access_token=access_token
         )
         self.assertEquals(channel.code, 401, channel.result)
@@ -262,7 +314,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "session": channel.json_body["session"],
         }
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"DELETE",
             "devices/" + device_id,
             access_token=access_token,
@@ -278,26 +330,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         access_token = self.login("kermit", "monkey")
 
         # we should now be able to make requests with the access token
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
 
         # Now try to hard logout this session
-        request, channel = self.make_request(
-            b"POST", "/logout", access_token=access_token
-        )
+        channel = self.make_request(b"POST", "/logout", access_token=access_token)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     @override_config({"session_lifetime": "24h"})
@@ -308,29 +354,313 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         access_token = self.login("kermit", "monkey")
 
         # we should now be able to make requests with the access token
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
-        request, channel = self.make_request(
-            b"GET", TEST_URL, access_token=access_token
-        )
+        channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
         self.assertEquals(channel.code, 401, channel.result)
         self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEquals(channel.json_body["soft_logout"], True)
 
         # Now try to hard log out all of the user's sessions
-        request, channel = self.make_request(
-            b"POST", "/logout/all", access_token=access_token
-        )
+        channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
 
+@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
+class MultiSSOTestCase(unittest.HomeserverTestCase):
+    """Tests for homeservers with multiple SSO providers enabled"""
+
+    servlets = [
+        login.register_servlets,
+    ]
+
+    def default_config(self) -> Dict[str, Any]:
+        config = super().default_config()
+
+        config["public_baseurl"] = BASE_URL
+
+        config["cas_config"] = {
+            "enabled": True,
+            "server_url": CAS_SERVER,
+            "service_url": "https://matrix.goodserver.com:8448",
+        }
+
+        config["saml2_config"] = {
+            "sp_config": {
+                "metadata": {"inline": [TEST_SAML_METADATA]},
+                # use the XMLSecurity backend to avoid relying on xmlsec1
+                "crypto_backend": "XMLSecurity",
+            },
+        }
+
+        # default OIDC provider
+        config["oidc_config"] = TEST_OIDC_CONFIG
+
+        # additional OIDC providers
+        config["oidc_providers"] = [
+            {
+                "idp_id": "idp1",
+                "idp_name": "IDP1",
+                "discover": False,
+                "issuer": "https://issuer1",
+                "client_id": "test-client-id",
+                "client_secret": "test-client-secret",
+                "scopes": ["profile"],
+                "authorization_endpoint": "https://issuer1/auth",
+                "token_endpoint": "https://issuer1/token",
+                "userinfo_endpoint": "https://issuer1/userinfo",
+                "user_mapping_provider": {
+                    "config": {"localpart_template": "{{ user.sub }}"}
+                },
+            }
+        ]
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d.update(build_synapse_client_resource_tree(self.hs))
+        return d
+
+    def test_get_login_flows(self):
+        """GET /login should return password and SSO flows"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        expected_flows = [
+            {"type": "m.login.cas"},
+            {"type": "m.login.sso"},
+            {"type": "m.login.token"},
+            {"type": "m.login.password"},
+        ] + ADDITIONAL_LOGIN_FLOWS
+
+        self.assertCountEqual(channel.json_body["flows"], expected_flows)
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_get_msc2858_login_flows(self):
+        """The SSO flow should include IdP info if MSC2858 is enabled"""
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # stick the flows results in a dict by type
+        flow_results = {}  # type: Dict[str, Any]
+        for f in channel.json_body["flows"]:
+            flow_type = f["type"]
+            self.assertNotIn(
+                flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
+            )
+            flow_results[flow_type] = f
+
+        self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
+        sso_flow = flow_results.pop("m.login.sso")
+        # we should have a set of IdPs
+        self.assertCountEqual(
+            sso_flow["org.matrix.msc2858.identity_providers"],
+            [
+                {"id": "cas", "name": "CAS"},
+                {"id": "saml", "name": "SAML"},
+                {"id": "oidc-idp1", "name": "IDP1"},
+                {"id": "oidc", "name": "OIDC"},
+            ],
+        )
+
+        # the rest of the flows are simple
+        expected_flows = [
+            {"type": "m.login.cas"},
+            {"type": "m.login.token"},
+            {"type": "m.login.password"},
+        ] + ADDITIONAL_LOGIN_FLOWS
+
+        self.assertCountEqual(flow_results.values(), expected_flows)
+
+    def test_multi_sso_redirect(self):
+        """/login/sso/redirect should redirect to an identity picker"""
+        # first hit the redirect url, which should redirect to our idp picker
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        uri = channel.headers.getRawHeaders("Location")[0]
+
+        # hitting that picker should give us some HTML
+        channel = self.make_request("GET", uri)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # parse the form to check it has fields assumed elsewhere in this class
+        p = TestHtmlParser()
+        p.feed(channel.result["body"].decode("utf-8"))
+        p.close()
+
+        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+
+        self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+
+    def test_multi_sso_redirect_to_cas(self):
+        """If CAS is chosen, should redirect to the CAS server"""
+
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+            + "&idp=cas",
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        cas_uri = channel.headers.getRawHeaders("Location")[0]
+        cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
+
+        # it should redirect us to the login page of the cas server
+        self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
+
+        # check that the redirectUrl is correctly encoded in the service param - ie, the
+        # place that CAS will redirect to
+        cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
+        service_uri = cas_uri_params["service"][0]
+        _, service_uri_query = service_uri.split("?", 1)
+        service_uri_params = urllib.parse.parse_qs(service_uri_query)
+        self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
+
+    def test_multi_sso_redirect_to_saml(self):
+        """If SAML is chosen, should redirect to the SAML server"""
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+            + "&idp=saml",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        saml_uri = channel.headers.getRawHeaders("Location")[0]
+        saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
+
+        # it should redirect us to the login page of the SAML server
+        self.assertEqual(saml_uri_path, SAML_SERVER)
+
+        # the RelayState is used to carry the client redirect url
+        saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
+        relay_state_param = saml_uri_params["RelayState"][0]
+        self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
+
+    def test_login_via_oidc(self):
+        """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
+
+        # pick the default OIDC provider
+        channel = self.make_request(
+            "GET",
+            "/_synapse/client/pick_idp?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+            + "&idp=oidc",
+        )
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+        # ... and should have set a cookie including the redirect url
+        cookies = dict(
+            h.split(";")[0].split("=", maxsplit=1)
+            for h in channel.headers.getRawHeaders("Set-Cookie")
+        )
+
+        oidc_session_cookie = cookies["oidc_session"]
+        macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
+        self.assertEqual(
+            self._get_value_from_macaroon(macaroon, "client_redirect_url"),
+            TEST_CLIENT_REDIRECT_URL,
+        )
+
+        channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+
+        # that should serve a confirmation page
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertTrue(
+            channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
+        )
+        p = TestHtmlParser()
+        p.feed(channel.text_body)
+        p.close()
+
+        # ... which should contain our redirect link
+        self.assertEqual(len(p.links), 1)
+        path, query = p.links[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
+        )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        login_token = params[2][1]
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@user1:test")
+
+    def test_multi_sso_redirect_to_unknown(self):
+        """An unknown IdP should cause a 400"""
+        channel = self.make_request(
+            "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+    def test_client_idp_redirect_msc2858_disabled(self):
+        """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_idp_redirect_to_unknown(self):
+        """If the client tries to pick an unknown IdP, return a 404"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+        self.assertEqual(channel.code, 404, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_idp_redirect_to_oidc(self):
+        """If the client pick a known IdP, redirect to it"""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+        )
+
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+    @staticmethod
+    def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+
 class CASTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -342,10 +672,12 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
 
         config = self.default_config()
+        config["public_baseurl"] = (
+            config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
+        )
         config["cas_config"] = {
             "enabled": True,
-            "server_url": "https://fake.test",
-            "service_url": "https://matrix.goodserver.com:8448",
+            "server_url": CAS_SERVER,
         }
 
         cas_user_id = "username"
@@ -402,10 +734,10 @@ class CASTestCase(unittest.HomeserverTestCase):
         cas_ticket_url = urllib.parse.urlunparse(url_parts)
 
         # Get Synapse to call the fake CAS and serve the template.
-        request, channel = self.make_request("GET", cas_ticket_url)
+        channel = self.make_request("GET", cas_ticket_url)
 
         # Test that the response is HTML.
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, 200, channel.result)
         content_type_header_value = ""
         for header in channel.result.get("headers", []):
             if header[0] == b"Content-Type":
@@ -430,8 +762,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_cas_redirect_whitelisted(self):
-        """Tests that the SSO login flow serves a redirect to a whitelisted url
-        """
+        """Tests that the SSO login flow serves a redirect to a whitelisted url"""
         self._test_redirect("https://legit-site.com/")
 
     @override_config({"public_baseurl": "https://example.com"})
@@ -446,7 +777,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         )
 
         # Get Synapse to call the fake CAS and serve the template.
-        request, channel = self.make_request("GET", cas_ticket_url)
+        channel = self.make_request("GET", cas_ticket_url)
 
         self.assertEqual(channel.code, 302)
         location_headers = channel.headers.getRawHeaders("Location")
@@ -462,7 +793,9 @@ class CASTestCase(unittest.HomeserverTestCase):
 
         # Deactivate the account.
         self.get_success(
-            self.deactivate_account_handler.deactivate_account(self.user_id, False)
+            self.deactivate_account_handler.deactivate_account(
+                self.user_id, False, create_requester(self.user_id)
+            )
         )
 
         # Request the CAS ticket.
@@ -472,13 +805,14 @@ class CASTestCase(unittest.HomeserverTestCase):
         )
 
         # Get Synapse to call the fake CAS and serve the template.
-        request, channel = self.make_request("GET", cas_ticket_url)
+        channel = self.make_request("GET", cas_ticket_url)
 
         # Because the user is deactivated they are served an error template.
         self.assertEqual(channel.code, 403)
         self.assertIn(b"SSO account deactivated", channel.result["body"])
 
 
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTTestCase(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -495,14 +829,18 @@ class JWTTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs
 
-    def jwt_encode(self, token, secret=jwt_secret):
-        return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
+        # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+        result = jwt.encode(
+            payload, secret, self.jwt_algorithm
+        )  # type: Union[str, bytes]
+        if isinstance(result, bytes):
+            return result.decode("ascii")
+        return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
     def test_login_jwt_valid_registered(self):
@@ -633,8 +971,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
         )
 
     def test_login_no_token(self):
-        params = json.dumps({"type": "org.matrix.login.jwt"})
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        params = {"type": "org.matrix.login.jwt"}
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -643,6 +981,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
 # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
 # signed by the private key.
+@skip_unless(HAS_JWT, "requires jwt")
 class JWTPubKeyTestCase(unittest.HomeserverTestCase):
     servlets = [
         login.register_servlets,
@@ -700,14 +1039,16 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         self.hs.config.jwt_algorithm = "RS256"
         return self.hs
 
-    def jwt_encode(self, token, secret=jwt_privatekey):
-        return jwt.encode(token, secret, "RS256").decode("ascii")
+    def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
+        # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+        result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes,str]
+        if isinstance(result, bytes):
+            return result.decode("ascii")
+        return result
 
     def jwt_login(self, *args):
-        params = json.dumps(
-            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
-        )
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+        channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
     def test_login_jwt_valid(self):
@@ -735,7 +1076,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
     ]
 
     def register_as_user(self, username):
-        request, channel = self.make_request(
+        self.make_request(
             b"POST",
             "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
             {"username": username},
@@ -776,60 +1117,56 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         return self.hs
 
     def test_login_appservice_user(self):
-        """Test that an appservice user can use /login
-        """
+        """Test that an appservice user can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": AS_USER},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_user_bot(self):
-        """Test that the appservice bot can use /login
-        """
+        """Test that the appservice bot can use /login"""
         self.register_as_user(AS_USER)
 
         params = {
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": self.service.sender},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_login_appservice_wrong_user(self):
-        """Test that non-as users cannot login with the as token
-        """
+        """Test that non-as users cannot login with the as token"""
         self.register_as_user(AS_USER)
 
         params = {
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
     def test_login_appservice_wrong_as(self):
-        """Test that as users cannot login with wrong as token
-        """
+        """Test that as users cannot login with wrong as token"""
         self.register_as_user(AS_USER)
 
         params = {
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": AS_USER},
         }
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", LOGIN_URL, params, access_token=self.another_service.token
         )
 
@@ -837,7 +1174,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
 
     def test_login_appservice_no_token(self):
         """Test that users must provide a token when using the appservice
-           login method
+        login method
         """
         self.register_as_user(AS_USER)
 
@@ -845,6 +1182,116 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             "type": login.LoginRestServlet.APPSERVICE_TYPE,
             "identifier": {"type": "m.id.user", "user": AS_USER},
         }
-        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        channel = self.make_request(b"POST", LOGIN_URL, params)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
+
+
+@skip_unless(HAS_OIDC, "requires OIDC")
+class UsernamePickerTestCase(HomeserverTestCase):
+    """Tests for the username picker flow of SSO login"""
+
+    servlets = [login.register_servlets]
+
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+
+        config["oidc_config"] = {}
+        config["oidc_config"].update(TEST_OIDC_CONFIG)
+        config["oidc_config"]["user_mapping_provider"] = {
+            "config": {"display_name_template": "{{ user.displayname }}"}
+        }
+
+        # whitelist this client URI so we redirect straight to it rather than
+        # serving a confirmation page
+        config["sso"] = {"client_whitelist": ["https://x"]}
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d.update(build_synapse_client_resource_tree(self.hs))
+        return d
+
+    def test_username_picker(self):
+        """Test the happy path of a username picker flow."""
+
+        # do the start of the login flow
+        channel = self.helper.auth_via_oidc(
+            {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+        )
+
+        # that should redirect to the username picker
+        self.assertEqual(channel.code, 302, channel.result)
+        picker_url = channel.headers.getRawHeaders("Location")[0]
+        self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
+
+        # ... with a username_mapping_session cookie
+        cookies = {}  # type: Dict[str,str]
+        channel.extract_cookies(cookies)
+        self.assertIn("username_mapping_session", cookies)
+        session_id = cookies["username_mapping_session"]
+
+        # introspect the sso handler a bit to check that the username mapping session
+        # looks ok.
+        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
+        self.assertIn(
+            session_id, username_mapping_sessions, "session id not found in map",
+        )
+        session = username_mapping_sessions[session_id]
+        self.assertEqual(session.remote_user_id, "tester")
+        self.assertEqual(session.display_name, "Jonny")
+        self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
+
+        # the expiry time should be about 15 minutes away
+        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
+        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+
+        # Now, submit a username to the username picker, which should serve a redirect
+        # to the completion page
+        content = urlencode({b"username": b"bobby"}).encode("utf8")
+        chan = self.make_request(
+            "POST",
+            path=picker_url,
+            content=content,
+            content_is_form=True,
+            custom_headers=[
+                ("Cookie", "username_mapping_session=" + session_id),
+                # old versions of twisted don't do form-parsing without a valid
+                # content-length header.
+                ("Content-Length", str(len(content))),
+            ],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+
+        # send a request to the completion page, which should 302 to the client redirectUrl
+        chan = self.make_request(
+            "GET",
+            path=location_headers[0],
+            custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+        )
+        self.assertEqual(chan.code, 302, chan.result)
+        location_headers = chan.headers.getRawHeaders("Location")
+
+        # ensure that the returned location matches the requested redirect URL
+        path, query = location_headers[0].split("?", 1)
+        self.assertEqual(path, "https://x")
+
+        # it will have url-encoded the params properly, so we'll have to parse them
+        params = urllib.parse.parse_qsl(
+            query, keep_blank_values=True, strict_parsing=True, errors="strict"
+        )
+        self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+        self.assertEqual(params[2][0], "loginToken")
+
+        # fish the login token out of the returned redirect uri
+        login_token = params[2][1]
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token, mxid, and device id.
+        chan = self.make_request(
+            "POST", "/login", content={"type": "m.login.token", "token": login_token},
+        )
+        self.assertEqual(chan.code, 200, chan.result)
+        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 5d5c24d01c..94a5154834 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -38,7 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(
             "red",
-            http_client=None,
+            federation_http_client=None,
             federation_client=Mock(),
             presence_handler=presence_handler,
         )
@@ -53,7 +53,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
         self.hs.config.use_presence = True
 
         body = {"presence": "here", "status_msg": "beep boop"}
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/presence/%s/status" % (self.user_id,), body
         )
 
@@ -68,7 +68,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
         self.hs.config.use_presence = False
 
         body = {"presence": "here", "status_msg": "beep boop"}
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/presence/%s/status" % (self.user_id,), body
         )
 
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 383a9eafac..e59fa70baa 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -63,7 +63,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             self.addCleanup,
             "test",
-            http_client=None,
+            federation_http_client=None,
             resource_for_client=self.mock_resource,
             federation=Mock(),
             federation_client=Mock(),
@@ -189,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.owner_tok = self.login("owner", "pass")
 
     def test_set_displayname(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/profile/%s/displayname" % (self.owner,),
             content=json.dumps({"displayname": "test"}),
@@ -202,7 +202,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
     def test_set_displayname_too_long(self):
         """Attempts to set a stupid displayname should get a 400"""
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/profile/%s/displayname" % (self.owner,),
             content=json.dumps({"displayname": "test" * 100}),
@@ -214,9 +214,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertEqual(res, "owner")
 
     def get_displayname(self):
-        request, channel = self.make_request(
-            "GET", "/profile/%s/displayname" % (self.owner,)
-        )
+        channel = self.make_request("GET", "/profile/%s/displayname" % (self.owner,))
         self.assertEqual(channel.code, 200, channel.result)
         return channel.json_body["displayname"]
 
@@ -278,7 +276,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
         )
 
     def request_profile(self, expected_code, url_suffix="", access_token=None):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.profile_url + url_suffix, access_token=access_token
         )
         self.assertEqual(channel.code, expected_code, channel.result)
@@ -320,19 +318,19 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
         """Tests that a user can lookup their own profile without having to be in a room
         if 'require_auth_for_profile_requests' is set to true in the server's config.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/profile/" + self.requester, access_token=self.requester_tok
         )
         self.assertEqual(channel.code, 200, channel.result)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/profile/" + self.requester + "/displayname",
             access_token=self.requester_tok,
         )
         self.assertEqual(channel.code, 200, channel.result)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/profile/" + self.requester + "/avatar_url",
             access_token=self.requester_tok,
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py
index 7add5523c8..2bc512d75e 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/v1/test_push_rule_attrs.py
@@ -45,13 +45,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # GET enabled for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -74,13 +74,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # disable the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/enabled",
             {"enabled": False},
@@ -89,26 +89,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # check rule disabled
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["enabled"], False)
 
         # DELETE the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", "/pushrules/global/override/best.friend", access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # GET enabled for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -130,13 +130,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # disable the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/enabled",
             {"enabled": False},
@@ -145,14 +145,14 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # check rule disabled
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["enabled"], False)
 
         # re-enable the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/enabled",
             {"enabled": True},
@@ -161,7 +161,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # check rule enabled
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -182,32 +182,32 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # GET enabled for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # DELETE the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", "/pushrules/global/override/best.friend", access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # check 404 for deleted rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
@@ -221,7 +221,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
@@ -235,7 +235,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # enable & check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/enabled",
             {"enabled": True},
@@ -252,7 +252,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # enable & check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/.m.muahahah/enabled",
             {"enabled": True},
@@ -276,13 +276,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # GET actions for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/actions", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -305,13 +305,13 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # change the rule actions
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/actions",
             {"actions": ["dont_notify"]},
@@ -320,7 +320,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # GET actions for that new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/actions", access_token=token
         )
         self.assertEqual(channel.code, 200)
@@ -341,26 +341,26 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         }
 
         # check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
         # PUT a new rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/pushrules/global/override/best.friend", body, access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # DELETE the rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", "/pushrules/global/override/best.friend", access_token=token
         )
         self.assertEqual(channel.code, 200)
 
         # check 404 for deleted rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/best.friend/enabled", access_token=token
         )
         self.assertEqual(channel.code, 404)
@@ -374,7 +374,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token
         )
         self.assertEqual(channel.code, 404)
@@ -388,7 +388,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # enable & check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/best.friend/actions",
             {"actions": ["dont_notify"]},
@@ -405,7 +405,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         token = self.login("user", "pass")
 
         # enable & check 404 for never-heard-of rule
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/pushrules/global/override/.m.muahahah/actions",
             {"actions": ["dont_notify"]},
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 49f1073c88..2548b3a80c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -26,9 +26,10 @@ from mock import Mock
 import synapse.rest.admin
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.handlers.pagination import PurgeStatus
+from synapse.rest import admin
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias, UserID
+from synapse.types import JsonDict, RoomAlias, UserID, create_requester
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -45,7 +46,7 @@ class RoomBase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         self.hs = self.setup_test_homeserver(
-            "red", http_client=None, federation_client=Mock(),
+            "red", federation_http_client=None, federation_client=Mock(),
         )
 
         self.hs.get_federation_handler = Mock()
@@ -83,13 +84,13 @@ class RoomPermissionsTestCase(RoomBase):
         self.created_rmid_msg_path = (
             "rooms/%s/send/m.room.message/a1" % (self.created_rmid)
         ).encode("ascii")
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
         )
         self.assertEquals(200, channel.code, channel.result)
 
         # set topic for public room
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
             b'{"topic":"Public Room Topic"}',
@@ -111,7 +112,7 @@ class RoomPermissionsTestCase(RoomBase):
             )
 
         # send message in uncreated room, expect 403
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
             msg_content,
@@ -119,24 +120,24 @@ class RoomPermissionsTestCase(RoomBase):
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # send message in created room not joined (no state), expect 403
-        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+        channel = self.make_request("PUT", send_msg_path(), msg_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # send message in created room and invited, expect 403
         self.helper.invite(
             room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
         )
-        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+        channel = self.make_request("PUT", send_msg_path(), msg_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # send message in created room and joined, expect 200
         self.helper.join(room=self.created_rmid, user=self.user_id)
-        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+        channel = self.make_request("PUT", send_msg_path(), msg_content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # send message in created room and left, expect 403
         self.helper.leave(room=self.created_rmid, user=self.user_id)
-        request, channel = self.make_request("PUT", send_msg_path(), msg_content)
+        channel = self.make_request("PUT", send_msg_path(), msg_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
     def test_topic_perms(self):
@@ -144,30 +145,30 @@ class RoomPermissionsTestCase(RoomBase):
         topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
 
         # set/get topic in uncreated room, expect 403
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
         )
         self.assertEquals(403, channel.code, msg=channel.result["body"])
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
         )
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # set/get topic in created PRIVATE room not joined, expect 403
-        request, channel = self.make_request("PUT", topic_path, topic_content)
+        channel = self.make_request("PUT", topic_path, topic_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
-        request, channel = self.make_request("GET", topic_path)
+        channel = self.make_request("GET", topic_path)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # set topic in created PRIVATE room and invited, expect 403
         self.helper.invite(
             room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
         )
-        request, channel = self.make_request("PUT", topic_path, topic_content)
+        channel = self.make_request("PUT", topic_path, topic_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # get topic in created PRIVATE room and invited, expect 403
-        request, channel = self.make_request("GET", topic_path)
+        channel = self.make_request("GET", topic_path)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # set/get topic in created PRIVATE room and joined, expect 200
@@ -175,29 +176,29 @@ class RoomPermissionsTestCase(RoomBase):
 
         # Only room ops can set topic by default
         self.helper.auth_user_id = self.rmcreator_id
-        request, channel = self.make_request("PUT", topic_path, topic_content)
+        channel = self.make_request("PUT", topic_path, topic_content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.helper.auth_user_id = self.user_id
 
-        request, channel = self.make_request("GET", topic_path)
+        channel = self.make_request("GET", topic_path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
 
         # set/get topic in created PRIVATE room and left, expect 403
         self.helper.leave(room=self.created_rmid, user=self.user_id)
-        request, channel = self.make_request("PUT", topic_path, topic_content)
+        channel = self.make_request("PUT", topic_path, topic_content)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
-        request, channel = self.make_request("GET", topic_path)
+        channel = self.make_request("GET", topic_path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # get topic in PUBLIC room, not joined, expect 403
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
         )
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         # set topic in PUBLIC room, not joined, expect 403
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
             topic_content,
@@ -207,7 +208,7 @@ class RoomPermissionsTestCase(RoomBase):
     def _test_get_membership(self, room=None, members=[], expect_code=None):
         for member in members:
             path = "/rooms/%s/state/m.room.member/%s" % (room, member)
-            request, channel = self.make_request("GET", path)
+            channel = self.make_request("GET", path)
             self.assertEquals(expect_code, channel.code)
 
     def test_membership_basic_room_perms(self):
@@ -379,16 +380,16 @@ class RoomsMemberListTestCase(RoomBase):
 
     def test_get_member_list(self):
         room_id = self.helper.create_room_as(self.user_id)
-        request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+        channel = self.make_request("GET", "/rooms/%s/members" % room_id)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_no_room(self):
-        request, channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
+        channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_no_permission(self):
         room_id = self.helper.create_room_as("@some_other_guy:red")
-        request, channel = self.make_request("GET", "/rooms/%s/members" % room_id)
+        channel = self.make_request("GET", "/rooms/%s/members" % room_id)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_mixed_memberships(self):
@@ -397,17 +398,17 @@ class RoomsMemberListTestCase(RoomBase):
         room_path = "/rooms/%s/members" % room_id
         self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
         # can't see list if you're just invited.
-        request, channel = self.make_request("GET", room_path)
+        channel = self.make_request("GET", room_path)
         self.assertEquals(403, channel.code, msg=channel.result["body"])
 
         self.helper.join(room=room_id, user=self.user_id)
         # can see list now joined
-        request, channel = self.make_request("GET", room_path)
+        channel = self.make_request("GET", room_path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         self.helper.leave(room=room_id, user=self.user_id)
         # can see old list once left
-        request, channel = self.make_request("GET", room_path)
+        channel = self.make_request("GET", room_path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
 
@@ -418,30 +419,26 @@ class RoomsCreateTestCase(RoomBase):
 
     def test_post_room_no_keys(self):
         # POST with no config keys, expect new room id
-        request, channel = self.make_request("POST", "/createRoom", "{}")
+        channel = self.make_request("POST", "/createRoom", "{}")
 
         self.assertEquals(200, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_visibility_key(self):
         # POST with visibility config key, expect new room id
-        request, channel = self.make_request(
-            "POST", "/createRoom", b'{"visibility":"private"}'
-        )
+        channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
         self.assertEquals(200, channel.code)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_custom_key(self):
         # POST with custom config keys, expect new room id
-        request, channel = self.make_request(
-            "POST", "/createRoom", b'{"custom":"stuff"}'
-        )
+        channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
         self.assertEquals(200, channel.code)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_known_and_unknown_keys(self):
         # POST with custom + known config keys, expect new room id
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
         )
         self.assertEquals(200, channel.code)
@@ -449,16 +446,16 @@ class RoomsCreateTestCase(RoomBase):
 
     def test_post_room_invalid_content(self):
         # POST with invalid content / paths, expect 400
-        request, channel = self.make_request("POST", "/createRoom", b'{"visibili')
+        channel = self.make_request("POST", "/createRoom", b'{"visibili')
         self.assertEquals(400, channel.code)
 
-        request, channel = self.make_request("POST", "/createRoom", b'["hello"]')
+        channel = self.make_request("POST", "/createRoom", b'["hello"]')
         self.assertEquals(400, channel.code)
 
     def test_post_room_invitees_invalid_mxid(self):
         # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
         # Note the trailing space in the MXID here!
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
         )
         self.assertEquals(400, channel.code)
@@ -476,54 +473,54 @@ class RoomTopicTestCase(RoomBase):
 
     def test_invalid_puts(self):
         # missing keys or invalid json
-        request, channel = self.make_request("PUT", self.path, "{}")
+        channel = self.make_request("PUT", self.path, "{}")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
+        channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, '{"nao')
+        channel = self.make_request("PUT", self.path, '{"nao')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
         )
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, "text only")
+        channel = self.make_request("PUT", self.path, "text only")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", self.path, "")
+        channel = self.make_request("PUT", self.path, "")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
         # valid key, wrong type
         content = '{"topic":["Topic name"]}'
-        request, channel = self.make_request("PUT", self.path, content)
+        channel = self.make_request("PUT", self.path, content)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
     def test_rooms_topic(self):
         # nothing should be there
-        request, channel = self.make_request("GET", self.path)
+        channel = self.make_request("GET", self.path)
         self.assertEquals(404, channel.code, msg=channel.result["body"])
 
         # valid put
         content = '{"topic":"Topic name"}'
-        request, channel = self.make_request("PUT", self.path, content)
+        channel = self.make_request("PUT", self.path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # valid get
-        request, channel = self.make_request("GET", self.path)
+        channel = self.make_request("GET", self.path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assert_dict(json.loads(content), channel.json_body)
 
     def test_rooms_topic_with_extra_keys(self):
         # valid put with extra keys
         content = '{"topic":"Seasons","subtopic":"Summer"}'
-        request, channel = self.make_request("PUT", self.path, content)
+        channel = self.make_request("PUT", self.path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # valid get
-        request, channel = self.make_request("GET", self.path)
+        channel = self.make_request("GET", self.path)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assert_dict(json.loads(content), channel.json_body)
 
@@ -539,24 +536,22 @@ class RoomMemberStateTestCase(RoomBase):
     def test_invalid_puts(self):
         path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
         # missing keys or invalid json
-        request, channel = self.make_request("PUT", path, "{}")
+        channel = self.make_request("PUT", path, "{}")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, '{"_name":"bo"}')
+        channel = self.make_request("PUT", path, '{"_name":"bo"}')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, '{"nao')
+        channel = self.make_request("PUT", path, '{"nao')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request(
-            "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
-        )
+        channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, "text only")
+        channel = self.make_request("PUT", path, "text only")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, "")
+        channel = self.make_request("PUT", path, "")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
         # valid keys, wrong types
@@ -565,7 +560,7 @@ class RoomMemberStateTestCase(RoomBase):
             Membership.JOIN,
             Membership.LEAVE,
         )
-        request, channel = self.make_request("PUT", path, content.encode("ascii"))
+        channel = self.make_request("PUT", path, content.encode("ascii"))
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
     def test_rooms_members_self(self):
@@ -576,10 +571,10 @@ class RoomMemberStateTestCase(RoomBase):
 
         # valid join message (NOOP since we made the room)
         content = '{"membership":"%s"}' % Membership.JOIN
-        request, channel = self.make_request("PUT", path, content.encode("ascii"))
+        channel = self.make_request("PUT", path, content.encode("ascii"))
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("GET", path, None)
+        channel = self.make_request("GET", path, None)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         expected_response = {"membership": Membership.JOIN}
@@ -594,10 +589,10 @@ class RoomMemberStateTestCase(RoomBase):
 
         # valid invite message
         content = '{"membership":"%s"}' % Membership.INVITE
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("GET", path, None)
+        channel = self.make_request("GET", path, None)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assertEquals(json.loads(content), channel.json_body)
 
@@ -613,18 +608,54 @@ class RoomMemberStateTestCase(RoomBase):
             Membership.INVITE,
             "Join us!",
         )
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("GET", path, None)
+        channel = self.make_request("GET", path, None)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
         self.assertEquals(json.loads(content), channel.json_body)
 
 
+class RoomInviteRatelimitTestCase(RoomBase):
+    user_id = "@sid1:red"
+
+    servlets = [
+        admin.register_servlets,
+        profile.register_servlets,
+        room.register_servlets,
+    ]
+
+    @unittest.override_config(
+        {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_rooms_ratelimit(self):
+        """Tests that invites in a room are actually rate-limited."""
+        room_id = self.helper.create_room_as(self.user_id)
+
+        for i in range(3):
+            self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+        self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invites_by_users_ratelimit(self):
+        """Tests that invites to a specific user are actually rate-limited."""
+
+        for i in range(3):
+            room_id = self.helper.create_room_as(self.user_id)
+            self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+        room_id = self.helper.create_room_as(self.user_id)
+        self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
 class RoomJoinRatelimitTestCase(RoomBase):
     user_id = "@sid1:red"
 
     servlets = [
+        admin.register_servlets,
         profile.register_servlets,
         room.register_servlets,
     ]
@@ -666,7 +697,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
 
         # Update the display name for the user.
         path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
-        request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+        channel = self.make_request("PUT", path, {"displayname": "John Doe"})
         self.assertEquals(channel.code, 200, channel.json_body)
 
         # Check that all the rooms have been sent a profile update into.
@@ -676,7 +707,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
                 self.user_id,
             )
 
-            request, channel = self.make_request("GET", path)
+            channel = self.make_request("GET", path)
             self.assertEquals(channel.code, 200)
 
             self.assertIn("displayname", channel.json_body)
@@ -700,9 +731,23 @@ class RoomJoinRatelimitTestCase(RoomBase):
             # Make sure we send more requests than the rate-limiting config would allow
             # if all of these requests ended up joining the user to a room.
             for i in range(4):
-                request, channel = self.make_request("POST", path % room_id, {})
+                channel = self.make_request("POST", path % room_id, {})
                 self.assertEquals(channel.code, 200)
 
+    @unittest.override_config(
+        {
+            "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}},
+            "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"],
+            "autocreate_auto_join_rooms": True,
+        },
+    )
+    def test_autojoin_rooms(self):
+        user_id = self.register_user("testuser", "password")
+
+        # Check that the new user successfully joined the four rooms
+        rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
+        self.assertEqual(len(rooms), 4)
+
 
 class RoomMessagesTestCase(RoomBase):
     """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -715,42 +760,40 @@ class RoomMessagesTestCase(RoomBase):
     def test_invalid_puts(self):
         path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
         # missing keys or invalid json
-        request, channel = self.make_request("PUT", path, b"{}")
+        channel = self.make_request("PUT", path, b"{}")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b'{"_name":"bo"}')
+        channel = self.make_request("PUT", path, b'{"_name":"bo"}')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b'{"nao')
+        channel = self.make_request("PUT", path, b'{"nao')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request(
-            "PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]'
-        )
+        channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b"text only")
+        channel = self.make_request("PUT", path, b"text only")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
-        request, channel = self.make_request("PUT", path, b"")
+        channel = self.make_request("PUT", path, b"")
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
     def test_rooms_messages_sent(self):
         path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
 
         content = b'{"body":"test","msgtype":{"type":"a"}}'
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(400, channel.code, msg=channel.result["body"])
 
         # custom message types
         content = b'{"body":"test","msgtype":"test.custom.text"}'
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
         # m.text message type
         path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
         content = b'{"body":"test2","msgtype":"m.text"}'
-        request, channel = self.make_request("PUT", path, content)
+        channel = self.make_request("PUT", path, content)
         self.assertEquals(200, channel.code, msg=channel.result["body"])
 
 
@@ -764,9 +807,7 @@ class RoomInitialSyncTestCase(RoomBase):
         self.room_id = self.helper.create_room_as(self.user_id)
 
     def test_initial_sync(self):
-        request, channel = self.make_request(
-            "GET", "/rooms/%s/initialSync" % self.room_id
-        )
+        channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
         self.assertEquals(200, channel.code)
 
         self.assertEquals(self.room_id, channel.json_body["room_id"])
@@ -807,7 +848,7 @@ class RoomMessageListTestCase(RoomBase):
 
     def test_topo_token_is_accepted(self):
         token = "t1-0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
         )
         self.assertEquals(200, channel.code)
@@ -818,7 +859,7 @@ class RoomMessageListTestCase(RoomBase):
 
     def test_stream_token_is_accepted_for_fwd_pagianation(self):
         token = "s0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
         )
         self.assertEquals(200, channel.code)
@@ -851,7 +892,7 @@ class RoomMessageListTestCase(RoomBase):
         self.helper.send(self.room_id, "message 3")
 
         # Check that we get the first and second message when querying /messages.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
             % (
@@ -879,7 +920,7 @@ class RoomMessageListTestCase(RoomBase):
 
         # Check that we only get the second message through /message now that the first
         # has been purged.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
             % (
@@ -896,7 +937,7 @@ class RoomMessageListTestCase(RoomBase):
         # Check that we get no event, but also no error, when querying /messages with
         # the token that was pointing at the first event, because we don't have it
         # anymore.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
             % (
@@ -955,7 +996,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
         self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
         self.helper.send(self.room, body="There!", tok=self.other_access_token)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/search?access_token=%s" % (self.access_token,),
             {
@@ -984,7 +1025,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
         self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
         self.helper.send(self.room, body="There!", tok=self.other_access_token)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/search?access_token=%s" % (self.access_token,),
             {
@@ -1032,14 +1073,14 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
         return self.hs
 
     def test_restricted_no_auth(self):
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
         self.assertEqual(channel.code, 401, channel.result)
 
     def test_restricted_auth(self):
         self.register_user("user", "pass")
         tok = self.login("user", "pass")
 
-        request, channel = self.make_request("GET", self.url, access_token=tok)
+        channel = self.make_request("GET", self.url, access_token=tok)
         self.assertEqual(channel.code, 200, channel.result)
 
 
@@ -1067,7 +1108,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
         self.displayname = "test user"
         data = {"displayname": self.displayname}
         request_data = json.dumps(data)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,),
             request_data,
@@ -1080,7 +1121,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
     def test_per_room_profile_forbidden(self):
         data = {"membership": "join", "displayname": "other test user"}
         request_data = json.dumps(data)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
             % (self.room_id, self.user_id),
@@ -1090,7 +1131,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
         event_id = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
             access_token=self.tok,
@@ -1123,7 +1164,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
 
     def test_join_reason(self):
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
             content={"reason": reason},
@@ -1137,7 +1178,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
 
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
             content={"reason": reason},
@@ -1151,7 +1192,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
 
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
             content={"reason": reason, "user_id": self.second_user_id},
@@ -1165,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
 
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
             content={"reason": reason, "user_id": self.second_user_id},
@@ -1177,7 +1218,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
 
     def test_unban_reason(self):
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
             content={"reason": reason, "user_id": self.second_user_id},
@@ -1189,7 +1230,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
 
     def test_invite_reason(self):
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
             content={"reason": reason, "user_id": self.second_user_id},
@@ -1208,7 +1249,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         )
 
         reason = "hello"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
             content={"reason": reason},
@@ -1219,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
         self._check_for_reason(reason)
 
     def _check_for_reason(self, reason):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
                 self.room_id, self.second_user_id
@@ -1268,7 +1309,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         """Test that we can filter by a label on a /context request."""
         event_id = self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/context/%s?filter=%s"
             % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
@@ -1298,7 +1339,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         """Test that we can filter by the absence of a label on a /context request."""
         event_id = self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/context/%s?filter=%s"
             % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
@@ -1333,7 +1374,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         """
         event_id = self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/context/%s?filter=%s"
             % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)),
@@ -1361,7 +1402,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         self._send_labelled_messages_in_room()
 
         token = "s0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
             % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)),
@@ -1378,7 +1419,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         self._send_labelled_messages_in_room()
 
         token = "s0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
             % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)),
@@ -1401,7 +1442,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
         self._send_labelled_messages_in_room()
 
         token = "s0_0_0_0_0_0_0_0_0"
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/messages?access_token=%s&from=%s&filter=%s"
             % (
@@ -1432,7 +1473,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
 
         self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/search?access_token=%s" % self.tok, request_data
         )
 
@@ -1467,7 +1508,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
 
         self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/search?access_token=%s" % self.tok, request_data
         )
 
@@ -1514,7 +1555,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
 
         self._send_labelled_messages_in_room()
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/search?access_token=%s" % self.tok, request_data
         )
 
@@ -1635,7 +1676,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
 
         # Check that we can still see the messages before the erasure request.
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
             % (self.room_id, event_id),
@@ -1681,7 +1722,9 @@ class ContextTestCase(unittest.HomeserverTestCase):
 
         deactivate_account_handler = self.hs.get_deactivate_account_handler()
         self.get_success(
-            deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
+            deactivate_account_handler.deactivate_account(
+                self.user_id, True, create_requester(self.user_id)
+            )
         )
 
         # Invite another user in the room. This is needed because messages will be
@@ -1699,7 +1742,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
         # Check that a user that joined the room after the erasure request can't see
         # the messages anymore.
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
             % (self.room_id, event_id),
@@ -1789,7 +1832,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
 
     def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
         """Calls the endpoint under test. returns the json response object."""
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases"
             % (self.room_id,),
@@ -1810,7 +1853,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
         data = {"room_id": self.room_id}
         request_data = json.dumps(data)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, request_data, access_token=self.room_owner_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
@@ -1840,14 +1883,14 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
         data = {"room_id": self.room_id}
         request_data = json.dumps(data)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", url, request_data, access_token=self.room_owner_tok
         )
         self.assertEqual(channel.code, expected_code, channel.result)
 
     def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
         """Calls the endpoint under test. returns the json response object."""
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
             access_token=self.room_owner_tok,
@@ -1859,7 +1902,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 
     def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
         """Calls the endpoint under test. returns the json response object."""
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
             json.dumps(content),
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index bbd30f594b..38c51525a3 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -39,7 +39,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         hs = self.setup_test_homeserver(
-            "red", http_client=None, federation_client=Mock(),
+            "red", federation_http_client=None, federation_client=Mock(),
         )
 
         self.event_source = hs.get_event_sources().sources["typing"]
@@ -94,7 +94,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         self.helper.join(self.room_id, user="@jim:red")
 
     def test_set_typing(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": true, "timeout": 30000}',
@@ -117,7 +117,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         )
 
     def test_set_not_typing(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": false}',
@@ -125,7 +125,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code)
 
     def test_typing_timeout(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": true, "timeout": 30000}',
@@ -138,7 +138,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 2)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": true, "timeout": 30000}',
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..b1333df82d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,12 @@
 # limitations under the License.
 
 import json
+import re
 import time
-from typing import Any, Dict, Optional
+import urllib.parse
+from typing import Any, Dict, Mapping, MutableMapping, Optional
+
+from mock import patch
 
 import attr
 
@@ -26,8 +30,11 @@ from twisted.web.resource import Resource
 from twisted.web.server import Site
 
 from synapse.api.constants import Membership
+from synapse.types import JsonDict
 
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
+from tests.test_utils import FakeResponse
+from tests.test_utils.html_parsers import TestHtmlParser
 
 
 @attr.s
@@ -75,7 +82,7 @@ class RestHelper:
         if tok:
             path = path + "?access_token=%s" % tok
 
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "POST",
@@ -151,7 +158,7 @@ class RestHelper:
         data = {"membership": membership}
         data.update(extra_data)
 
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "PUT",
@@ -186,7 +193,7 @@ class RestHelper:
         if tok:
             path = path + "?access_token=%s" % tok
 
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "PUT",
@@ -242,9 +249,7 @@ class RestHelper:
         if body is not None:
             content = json.dumps(body).encode("utf8")
 
-        _, channel = make_request(
-            self.hs.get_reactor(), self.site, method, path, content
-        )
+        channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
@@ -327,7 +332,7 @@ class RestHelper:
         """
         image_length = len(image_data)
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             FakeSite(resource),
             "POST",
@@ -344,3 +349,246 @@ class RestHelper:
         )
 
         return channel.json_body
+
+    def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+        """Log in (as a new user) via OIDC
+
+        Returns the result of the final token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+        client_redirect_url = "https://x"
+        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+
+        # expect a confirmation page
+        assert channel.code == 200, channel.result
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.text_body,
+        )
+        assert m, channel.text_body
+        login_token = m.group(1)
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token and device id.
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        assert channel.code == 200
+        return channel.json_body
+
+    def auth_via_oidc(
+        self,
+        user_info_dict: JsonDict,
+        client_redirect_url: Optional[str] = None,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> FakeChannel:
+        """Perform an OIDC authentication flow via a mock OIDC provider.
+
+        This can be used for either login or user-interactive auth.
+
+        Starts by making a request to the relevant synapse redirect endpoint, which is
+        expected to serve a 302 to the OIDC provider. We then make a request to the
+        OIDC callback endpoint, intercepting the HTTP requests that will get sent back
+        to the OIDC provider.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+
+        Args:
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+            client_redirect_url: for a login flow, the client redirect URL to pass to
+                the login redirect endpoint
+            ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
+                of the UI auth.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+            Note that the response code may be a 200, 302 or 400 depending on how things
+            went.
+        """
+
+        cookies = {}
+
+        # if we're doing a ui auth, hit the ui auth redirect endpoint
+        if ui_auth_session_id:
+            # can't set the client redirect url for UI Auth
+            assert client_redirect_url is None
+            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+        else:
+            # otherwise, hit the login redirect endpoint
+            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+
+        # we now have a URI for the OIDC IdP, but we skip that and go straight
+        # back to synapse's OIDC callback resource. However, we do need the "state"
+        # param that synapse passes to the IdP via query params, as well as the cookie
+        # that synapse passes to the client.
+
+        oauth_uri_path, _ = oauth_uri.split("?", 1)
+        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+            "unexpected SSO URI " + oauth_uri_path
+        )
+        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+    def complete_oidc_auth(
+        self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+    ) -> FakeChannel:
+        """Mock out an OIDC authentication flow
+
+        Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+        initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+        Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+        sent back to the OIDC provider.
+
+        Requires the OIDC callback resource to be mounted at the normal place.
+
+        Args:
+            oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+               from initiate_sso_login or initiate_sso_ui_auth).
+            cookies: the cookies set by synapse's redirect endpoint, which will be
+               sent back to the callback endpoint.
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+        """
+        _, oauth_uri_qs = oauth_uri.split("?", 1)
+        params = urllib.parse.parse_qs(oauth_uri_qs)
+        callback_uri = "%s?%s" % (
+            urllib.parse.urlparse(params["redirect_uri"][0]).path,
+            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+        )
+
+        # before we hit the callback uri, stub out some methods in the http client so
+        # that we don't have to handle full HTTPS requests.
+        # (expected url, json response) pairs, in the order we expect them.
+        expected_requests = [
+            # first we get a hit to the token endpoint, which we tell to return
+            # a dummy OIDC access token
+            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
+            # and then one to the user_info endpoint, which returns our remote user id.
+            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
+        ]
+
+        async def mock_req(method: str, uri: str, data=None, headers=None):
+            (expected_uri, resp_obj) = expected_requests.pop(0)
+            assert uri == expected_uri
+            resp = FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+            )
+            return resp
+
+        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+            # now hit the callback URI with the right params and a made-up code
+            channel = make_request(
+                self.hs.get_reactor(),
+                self.site,
+                "GET",
+                callback_uri,
+                custom_headers=[
+                    ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+                ],
+            )
+        return channel
+
+    def initiate_sso_login(
+        self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+    ) -> str:
+        """Make a request to the login-via-sso redirect endpoint, and return the target
+
+        Assumes that exactly one SSO provider has been configured. Requires the login
+        servlet to be mounted.
+
+        Args:
+            client_redirect_url: the client redirect URL to pass to the login redirect
+                endpoint
+            cookies: any cookies returned will be added to this dict
+
+        Returns:
+            the URI that the client gets redirected to (ie, the SSO server)
+        """
+        params = {}
+        if client_redirect_url:
+            params["redirectUrl"] = client_redirect_url
+
+        # hit the redirect url (which will issue a cookie and state)
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+        )
+
+        assert channel.code == 302
+        channel.extract_cookies(cookies)
+        return channel.headers.getRawHeaders("Location")[0]
+
+    def initiate_sso_ui_auth(
+        self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
+    ) -> str:
+        """Make a request to the ui-auth-via-sso endpoint, and return the target
+
+        Assumes that exactly one SSO provider has been configured. Requires the
+        AuthRestServlet to be mounted.
+
+        Args:
+            ui_auth_session_id: the session id of the UI auth
+            cookies: any cookies returned will be added to this dict
+
+        Returns:
+            the URI that the client gets linked to (ie, the SSO server)
+        """
+        sso_redirect_endpoint = (
+            "/_matrix/client/r0/auth/m.login.sso/fallback/web?"
+            + urllib.parse.urlencode({"session": ui_auth_session_id})
+        )
+        # hit the redirect url (which will issue a cookie and state)
+        channel = make_request(
+            self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
+        )
+        # that should serve a confirmation page
+        assert channel.code == 200, channel.text_body
+        channel.extract_cookies(cookies)
+
+        # parse the confirmation page to fish out the link.
+        p = TestHtmlParser()
+        p.feed(channel.text_body)
+        p.close()
+        assert len(p.links) == 1, "not exactly one link in confirmation page"
+        oauth_uri = p.links[0]
+        return oauth_uri
+
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
+TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
+TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
+TEST_OIDC_CONFIG = {
+    "enabled": True,
+    "discover": False,
+    "issuer": "https://issuer.test",
+    "client_id": "test-client-id",
+    "client_secret": "test-client-secret",
+    "scopes": ["profile"],
+    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
+    "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
+    "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
+    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 2ac1ecb7d3..177dc476da 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -19,13 +19,12 @@ import os
 import re
 from email.parser import Parser
 from typing import Optional
-from urllib.parse import urlencode
 
 import pkg_resources
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import account, register
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -113,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the old password
         self.attempt_wrong_password_login("kermit", old_password)
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_email(self):
+        """Test that we ratelimit /requestToken for the same email.
+        """
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test1@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        def reset(ip):
+            client_secret = "foobar"
+            session_id = self._request_token(email, client_secret, ip)
+
+            self.assertEquals(len(self.email_attempts), 1)
+            link = self._get_link_from_email()
+
+            self._validate_token(link)
+
+            self._reset_password(new_password, session_id, client_secret)
+
+            self.email_attempts.clear()
+
+        # We expect to be able to make three requests before getting rate
+        # limited.
+        #
+        # We change IPs to ensure that we're not being ratelimited due to the
+        # same IP
+        reset("127.0.0.1")
+        reset("127.0.0.2")
+        reset("127.0.0.3")
+
+        with self.assertRaises(HttpResponseException) as cm:
+            reset("127.0.0.4")
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_basic_password_reset_canonicalise_email(self):
         """Test basic password reset flow
         Request password reset with different spelling
@@ -240,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         self.assertIsNotNone(session_id)
 
-    def _request_token(self, email, client_secret):
-        request, channel = self.make_request(
+    def _request_token(self, email, client_secret, ip="127.0.0.1"):
+        channel = self.make_request(
             "POST",
             b"account/password/email/requestToken",
             {"client_secret": client_secret, "email": email, "send_attempt": 1},
+            client_ip=ip,
         )
-        self.assertEquals(200, channel.code, channel.result)
+
+        if channel.code != 200:
+            raise HttpResponseException(
+                channel.code, channel.result["reason"], channel.result["body"],
+            )
 
         return channel.json_body["sid"]
 
@@ -255,7 +309,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         path = link.replace("https://example.com", "")
 
         # Load the password reset confirmation page
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.submit_token_resource),
             "GET",
@@ -268,20 +322,13 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Now POST to the same endpoint, mimicking the same behaviour as clicking the
         # password reset confirm button
 
-        # Send arguments as url-encoded form data, matching the template's behaviour
-        form_args = []
-        for key, value_list in request.args.items():
-            for value in value_list:
-                arg = (key, value)
-                form_args.append(arg)
-
         # Confirm the password reset
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.submit_token_resource),
             "POST",
             path,
-            content=urlencode(form_args).encode("utf8"),
+            content=b"",
             shorthand=False,
             content_is_form=True,
         )
@@ -310,7 +357,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
     def _reset_password(
         self, new_password, session_id, client_secret, expected_code=200
     ):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"account/password",
             {
@@ -352,8 +399,8 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
 
         # Check that this access token has been invalidated.
-        request, channel = self.make_request("GET", "account/whoami")
-        self.assertEqual(request.code, 401)
+        channel = self.make_request("GET", "account/whoami")
+        self.assertEqual(channel.code, 401)
 
     def test_pending_invites(self):
         """Tests that deactivating a user rejects every pending invite for them."""
@@ -407,10 +454,10 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
                 "erase": False,
             }
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "account/deactivate", request_data, access_token=tok
         )
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
 
 
 class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
@@ -517,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def test_address_trim(self):
         self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_ip(self):
+        """Tests that adding emails is ratelimited by IP
+        """
+
+        # We expect to be able to set three emails before getting ratelimited.
+        self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+        self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+        self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+        with self.assertRaises(HttpResponseException) as cm:
+            self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_add_email_if_disabled(self):
         """Test adding email to profile when doing so is disallowed
         """
@@ -530,7 +592,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         self._validate_token(link)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"/_matrix/client/unstable/account/3pid/add",
             {
@@ -548,7 +610,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -569,7 +631,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"account/3pid/delete",
             {"medium": "email", "address": self.email},
@@ -578,7 +640,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -601,7 +663,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"account/3pid/delete",
             {"medium": "email", "address": self.email},
@@ -612,7 +674,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -629,7 +691,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEquals(len(self.email_attempts), 1)
 
         # Attempt to add email without clicking the link
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"/_matrix/client/unstable/account/3pid/add",
             {
@@ -647,7 +709,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -662,7 +724,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         session_id = "weasle"
 
         # Attempt to add email without even requesting an email
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"/_matrix/client/unstable/account/3pid/add",
             {
@@ -680,7 +742,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
@@ -784,17 +846,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         if next_link:
             body["next_link"] = next_link
 
-        request, channel = self.make_request(
-            "POST", b"account/3pid/email/requestToken", body,
-        )
-        self.assertEquals(expect_code, channel.code, channel.result)
+        channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
+
+        if channel.code != expect_code:
+            raise HttpResponseException(
+                channel.code, channel.result["reason"], channel.result["body"],
+            )
 
         return channel.json_body.get("sid")
 
     def _request_token_invalid_email(
         self, email, expected_errcode, expected_error, client_secret="foobar",
     ):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"account/3pid/email/requestToken",
             {"client_secret": client_secret, "email": email, "send_attempt": 1},
@@ -807,7 +871,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         # Remove the host
         path = link.replace("https://example.com", "")
 
-        request, channel = self.make_request("GET", path, shorthand=False)
+        channel = self.make_request("GET", path, shorthand=False)
         self.assertEquals(200, channel.code, channel.result)
 
     def _get_link_from_email(self):
@@ -833,15 +897,17 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def _add_email(self, request_email, expected_email):
         """Test adding an email to profile
         """
+        previous_email_attempts = len(self.email_attempts)
+
         client_secret = "foobar"
         session_id = self._request_token(request_email, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
         link = self._get_link_from_email()
 
         self._validate_token(link)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"/_matrix/client/unstable/account/3pid/add",
             {
@@ -859,10 +925,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         # Get user
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url_3pid, access_token=self.user_id_tok,
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+        threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+        self.assertIn(expected_email, threepids)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 77246e478f..3f50c56745 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 New Vector
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,20 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Union
+from typing import Union
 
 from twisted.internet.defer import succeed
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.http.site import SynapseRequest
 from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.types import JsonDict, UserID
 
 from tests import unittest
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
+from tests.unittest import override_config, skip_unless
 
 
 class DummyRecaptchaChecker(UserInteractiveAuthChecker):
@@ -64,11 +68,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
 
     def register(self, expected_response: int, body: JsonDict) -> FakeChannel:
         """Make a register request."""
-        request, channel = self.make_request(
-            "POST", "register", body
-        )  # type: SynapseRequest, FakeChannel
+        channel = self.make_request("POST", "register", body)
 
-        self.assertEqual(request.code, expected_response)
+        self.assertEqual(channel.code, expected_response)
         return channel
 
     def recaptcha(
@@ -78,18 +80,18 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         if post_session is None:
             post_session = session
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request.code, 200)
+        )
+        self.assertEqual(channel.code, 200)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "auth/m.login.recaptcha/fallback/web?session="
             + post_session
             + "&g-recaptcha-response=a",
         )
-        self.assertEqual(request.code, expected_post_response)
+        self.assertEqual(channel.code, expected_post_response)
 
         # The recaptcha handler is called with the response given
         attempts = self.recaptcha_checker.recaptcha_attempts
@@ -156,31 +158,44 @@ class UIAuthTests(unittest.HomeserverTestCase):
         register.register_servlets,
     ]
 
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = "https://synapse.test"
+
+        if HAS_OIDC:
+            # we enable OIDC as a way of testing SSO flows
+            oidc_config = {}
+            oidc_config.update(TEST_OIDC_CONFIG)
+            oidc_config["allow_existing_users"] = True
+            config["oidc_config"] = oidc_config
+
+        return config
+
+    def create_resource_dict(self):
+        resource_dict = super().create_resource_dict()
+        resource_dict.update(build_synapse_client_resource_tree(self.hs))
+        return resource_dict
+
     def prepare(self, reactor, clock, hs):
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
-        self.user_tok = self.login("test", self.user_pass)
-
-    def get_device_ids(self) -> List[str]:
-        # Get the list of devices so one can be deleted.
-        request, channel = self.make_request(
-            "GET", "devices", access_token=self.user_tok,
-        )  # type: SynapseRequest, FakeChannel
-
-        # Get the ID of the device.
-        self.assertEqual(request.code, 200)
-        return [d["device_id"] for d in channel.json_body["devices"]]
+        self.device_id = "dev1"
+        self.user_tok = self.login("test", self.user_pass, self.device_id)
 
     def delete_device(
-        self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+        self,
+        access_token: str,
+        device: str,
+        expected_response: int,
+        body: Union[bytes, JsonDict] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
-        request, channel = self.make_request(
-            "DELETE", "devices/" + device, body, access_token=self.user_tok
-        )  # type: SynapseRequest, FakeChannel
+        channel = self.make_request(
+            "DELETE", "devices/" + device, body, access_token=access_token,
+        )
 
         # Ensure the response is sane.
-        self.assertEqual(request.code, expected_response)
+        self.assertEqual(channel.code, expected_response)
 
         return channel
 
@@ -188,12 +203,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """Delete 1 or more devices."""
         # Note that this uses the delete_devices endpoint so that we can modify
         # the payload half-way through some tests.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "delete_devices", body, access_token=self.user_tok,
-        )  # type: SynapseRequest, FakeChannel
+        )
 
         # Ensure the response is sane.
-        self.assertEqual(request.code, expected_response)
+        self.assertEqual(channel.code, expected_response)
 
         return channel
 
@@ -201,11 +216,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """
         Test user interactive authentication outside of registration.
         """
-        device_id = self.get_device_ids()[0]
-
         # Attempt to delete this device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(device_id, 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -214,7 +227,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Make another request providing the UI auth flow.
         self.delete_device(
-            device_id,
+            self.user_tok,
+            self.device_id,
             200,
             {
                 "auth": {
@@ -233,13 +247,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         UIA - check that still works.
         """
 
-        device_id = self.get_device_ids()[0]
-        channel = self.delete_device(device_id, 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
         session = channel.json_body["session"]
 
         # Make another request providing the UI auth flow.
         self.delete_device(
-            device_id,
+            self.user_tok,
+            self.device_id,
             200,
             {
                 "auth": {
@@ -262,14 +276,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
         session ID should be rejected.
         """
         # Create a second login.
-        self.login("test", self.user_pass)
-
-        device_ids = self.get_device_ids()
-        self.assertEqual(len(device_ids), 2)
+        self.login("test", self.user_pass, "dev2")
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_devices(401, {"devices": [device_ids[0]]})
+        channel = self.delete_devices(401, {"devices": [self.device_id]})
 
         # Grab the session
         session = channel.json_body["session"]
@@ -281,7 +292,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.delete_devices(
             200,
             {
-                "devices": [device_ids[1]],
+                "devices": ["dev2"],
                 "auth": {
                     "type": "m.login.password",
                     "identifier": {"type": "m.id.user", "user": self.user},
@@ -296,14 +307,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
         The initial requested URI cannot be modified during the user interactive authentication session.
         """
         # Create a second login.
-        self.login("test", self.user_pass)
-
-        device_ids = self.get_device_ids()
-        self.assertEqual(len(device_ids), 2)
+        self.login("test", self.user_pass, "dev2")
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(device_ids[0], 401)
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -312,8 +320,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Make another request providing the UI auth flow, but try to delete the
         # second device. This results in an error.
+        #
+        # This makes use of the fact that the device ID is embedded into the URL.
         self.delete_device(
-            device_ids[1],
+            self.user_tok,
+            "dev2",
             403,
             {
                 "auth": {
@@ -324,3 +335,152 @@ class UIAuthTests(unittest.HomeserverTestCase):
                 },
             },
         )
+
+    @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}})
+    def test_can_reuse_session(self):
+        """
+        The session can be reused if configured.
+
+        Compare to test_cannot_change_uri.
+        """
+        # Create a second and third login.
+        self.login("test", self.user_pass, "dev2")
+        self.login("test", self.user_pass, "dev3")
+
+        # Attempt to delete a device. This works since the user just logged in.
+        self.delete_device(self.user_tok, "dev2", 200)
+
+        # Move the clock forward past the validation timeout.
+        self.reactor.advance(6)
+
+        # Deleting another devices throws the user into UI auth.
+        channel = self.delete_device(self.user_tok, "dev3", 401)
+
+        # Grab the session
+        session = channel.json_body["session"]
+        # Ensure that flows are what is expected.
+        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+        # Make another request providing the UI auth flow.
+        self.delete_device(
+            self.user_tok,
+            "dev3",
+            200,
+            {
+                "auth": {
+                    "type": "m.login.password",
+                    "identifier": {"type": "m.id.user", "user": self.user},
+                    "password": self.user_pass,
+                    "session": session,
+                },
+            },
+        )
+
+        # Make another request, but try to delete the first device. This works
+        # due to re-using the previous session.
+        #
+        # Note that *no auth* information is provided, not even a session iD!
+        self.delete_device(self.user_tok, self.device_id, 200)
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_ui_auth_via_sso(self):
+        """Test a successful UI Auth flow via SSO
+
+        This includes:
+          * hitting the UIA SSO redirect endpoint
+          * checking it serves a confirmation page which links to the OIDC provider
+          * calling back to the synapse oidc callback
+          * checking that the original operation succeeds
+        """
+
+        # log the user in
+        remote_user_id = UserID.from_string(self.user).localpart
+        login_resp = self.helper.login_via_oidc(remote_user_id)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        # initiate a UI Auth process by attempting to delete the device
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        # check that SSO is offered
+        flows = channel.json_body["flows"]
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+
+        # run the UIA-via-SSO flow
+        session_id = channel.json_body["session"]
+        channel = self.helper.auth_via_oidc(
+            {"sub": remote_user_id}, ui_auth_session_id=session_id
+        )
+
+        # that should serve a confirmation page
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # and now the delete request should succeed.
+        self.delete_device(
+            self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}},
+        )
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_does_not_offer_password_for_sso_user(self):
+        login_resp = self.helper.login_via_oidc("username")
+        user_tok = login_resp["access_token"]
+        device_id = login_resp["device_id"]
+
+        # now call the device deletion API: we should get the option to auth with SSO
+        # and not password.
+        channel = self.delete_device(user_tok, device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+    def test_does_not_offer_sso_for_password_user(self):
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_offers_both_flows_for_upgraded_user(self):
+        """A user that had a password and then logged in with SSO should get both flows
+        """
+        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        flows = channel.json_body["flows"]
+        # we have no particular expectations of ordering here
+        self.assertIn({"stages": ["m.login.password"]}, flows)
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+        self.assertEqual(len(flows), 2)
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_ui_auth_fails_for_incorrect_sso_user(self):
+        """If the user tries to authenticate with the wrong SSO user, they get an error
+        """
+        # log the user in
+        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        # start a UI Auth flow by attempting to delete a device
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+        session_id = channel.json_body["session"]
+
+        # do the OIDC auth, but auth as the wrong user
+        channel = self.helper.auth_via_oidc(
+            {"sub": "wrong_user"}, ui_auth_session_id=session_id
+        )
+
+        # that should return a failure message
+        self.assertSubstring("We were unable to validate", channel.text_body)
+
+        # ... and the delete op should now fail with a 403
+        self.delete_device(
+            self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
+        )
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index 767e126875..e808339fb3 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -36,7 +36,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         return hs
 
     def test_check_auth_required(self):
-        request, channel = self.make_request("GET", self.url)
+        channel = self.make_request("GET", self.url)
 
         self.assertEqual(channel.code, 401)
 
@@ -44,7 +44,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.register_user("user", "pass")
         access_token = self.login("user", "pass")
 
-        request, channel = self.make_request("GET", self.url, access_token=access_token)
+        channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
@@ -62,7 +62,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         user = self.register_user(localpart, password)
         access_token = self.login(user, password)
 
-        request, channel = self.make_request("GET", self.url, access_token=access_token)
+        channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
@@ -70,7 +70,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         # Test case where password is handled outside of Synapse
         self.assertTrue(capabilities["m.change_password"]["enabled"])
         self.get_success(self.store.user_set_password_hash(user, None))
-        request, channel = self.make_request("GET", self.url, access_token=access_token)
+        channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 231d5aefea..f761c44936 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -36,7 +36,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastore()
 
     def test_add_filter(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
@@ -49,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEquals(filter.result, self.EXAMPLE_FILTER)
 
     def test_add_filter_for_other_user(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
             self.EXAMPLE_FILTER_JSON,
@@ -61,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
     def test_add_filter_non_local_user(self):
         _is_mine = self.hs.is_mine
         self.hs.is_mine = lambda target_user: False
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
@@ -79,7 +79,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         )
         self.reactor.advance(1)
         filter_id = filter_id.result
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
         )
 
@@ -87,7 +87,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
 
     def test_get_filter_non_existant(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
         )
 
@@ -97,7 +97,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
     # Currently invalid params do not have an appropriate errcode
     # in errors.py
     def test_get_filter_invalid_id(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
         )
 
@@ -105,7 +105,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
     # No ID also returns an invalid_id error
     def test_get_filter_no_id(self):
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
         )
 
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py
index ee86b94917..fba34def30 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/v2_alpha/test_password_policy.py
@@ -70,9 +70,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
     def test_get_policy(self):
         """Tests if the /password_policy endpoint returns the configured policy."""
 
-        request, channel = self.make_request(
-            "GET", "/_matrix/client/r0/password_policy"
-        )
+        channel = self.make_request("GET", "/_matrix/client/r0/password_policy")
 
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(
@@ -89,7 +87,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_too_short(self):
         request_data = json.dumps({"username": "kermit", "password": "shorty"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -98,7 +96,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_no_digit(self):
         request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -107,7 +105,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_no_symbol(self):
         request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -116,7 +114,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_no_uppercase(self):
         request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -125,7 +123,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_no_lowercase(self):
         request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(
@@ -134,7 +132,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
 
     def test_password_compliant(self):
         request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
-        request, channel = self.make_request("POST", self.register_url, request_data)
+        channel = self.make_request("POST", self.register_url, request_data)
 
         # Getting a 401 here means the password has passed validation and the server has
         # responded with a list of registration flows.
@@ -160,7 +158,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
                 },
             }
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/account/password",
             request_data,
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 8f0c2430e8..27db4f551e 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -61,7 +61,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.hs.get_datastore().services_cache.append(appservice)
         request_data = json.dumps({"username": "as_user_kermit"})
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
@@ -72,7 +72,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     def test_POST_appservice_registration_invalid(self):
         self.appservice = None  # no application service exists
         request_data = json.dumps({"username": "kermit"})
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
@@ -80,14 +80,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
     def test_POST_bad_password(self):
         request_data = json.dumps({"username": "kermit", "password": 666})
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid password")
 
     def test_POST_bad_username(self):
         request_data = json.dumps({"username": 777, "password": "monkey"})
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEquals(channel.result["code"], b"400", channel.result)
         self.assertEquals(channel.json_body["error"], "Invalid username")
@@ -102,7 +102,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "auth": {"type": LoginType.DUMMY},
         }
         request_data = json.dumps(params)
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         det_data = {
             "user_id": user_id,
@@ -117,16 +117,17 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
         self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
 
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Registration has been disabled")
+        self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
 
     def test_POST_guest_registration(self):
         self.hs.config.macaroon_secret_key = "test"
         self.hs.config.allow_guest_access = True
 
-        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
         self.assertEquals(channel.result["code"], b"200", channel.result)
@@ -135,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     def test_POST_disabled_guest_registration(self):
         self.hs.config.allow_guest_access = False
 
-        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
@@ -144,7 +145,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     def test_POST_ratelimiting_guest(self):
         for i in range(0, 6):
             url = self.url + b"?kind=guest"
-            request, channel = self.make_request(b"POST", url, b"{}")
+            channel = self.make_request(b"POST", url, b"{}")
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -154,7 +155,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
-        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
@@ -168,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 "auth": {"type": LoginType.DUMMY},
             }
             request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", self.url, request_data)
+            channel = self.make_request(b"POST", self.url, request_data)
 
             if i == 5:
                 self.assertEquals(channel.result["code"], b"429", channel.result)
@@ -178,12 +179,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
-        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_advertised_flows(self):
-        request, channel = self.make_request(b"POST", self.url, b"{}")
+        channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
 
@@ -206,7 +207,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_advertised_flows_captcha_and_terms_and_3pids(self):
-        request, channel = self.make_request(b"POST", self.url, b"{}")
+        channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
 
@@ -238,7 +239,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         }
     )
     def test_advertised_flows_no_msisdn_email_required(self):
-        request, channel = self.make_request(b"POST", self.url, b"{}")
+        channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
 
@@ -278,7 +279,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             b"register/email/requestToken",
             {"client_secret": "foobar", "email": email, "send_attempt": 1},
@@ -317,13 +318,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
 
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(
@@ -345,14 +346,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/account_validity/validity"
         params = {"user_id": user_id}
         request_data = json.dumps(params)
-        request, channel = self.make_request(
-            b"POST", url, request_data, access_token=admin_tok
-        )
+        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_manual_expire(self):
@@ -369,14 +368,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
             "enable_renewal_emails": False,
         }
         request_data = json.dumps(params)
-        request, channel = self.make_request(
-            b"POST", url, request_data, access_token=admin_tok
-        )
+        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(
             channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
@@ -396,20 +393,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
             "enable_renewal_emails": False,
         }
         request_data = json.dumps(params)
-        request, channel = self.make_request(
-            b"POST", url, request_data, access_token=admin_tok
-        )
+        channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # Try to log the user out
-        request, channel = self.make_request(b"POST", "/logout", access_token=tok)
+        channel = self.make_request(b"POST", "/logout", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # Log the user in again (allowed for expired accounts)
         tok = self.login("kermit", "monkey")
 
         # Try to log out all of the user's sessions
-        request, channel = self.make_request(b"POST", "/logout/all", access_token=tok)
+        channel = self.make_request(b"POST", "/logout/all", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
 
@@ -483,7 +478,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # retrieve the token from the DB.
         renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
         url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
-        request, channel = self.make_request(b"GET", url)
+        channel = self.make_request(b"GET", url)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         # Check that we're getting HTML back.
@@ -503,14 +498,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # our access token should be denied from now, otherwise they should
         # succeed.
         self.reactor.advance(datetime.timedelta(days=3).total_seconds())
-        request, channel = self.make_request(b"GET", "/sync", access_token=tok)
+        channel = self.make_request(b"GET", "/sync", access_token=tok)
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_renewal_invalid_token(self):
         # Hit the renewal endpoint with an invalid token and check that it behaves as
         # expected, i.e. that it responds with 404 Not Found and the correct HTML.
         url = "/_matrix/client/unstable/account_validity/renew?token=123"
-        request, channel = self.make_request(b"GET", url)
+        channel = self.make_request(b"GET", url)
         self.assertEquals(channel.result["code"], b"404", channel.result)
 
         # Check that we're getting HTML back.
@@ -531,7 +526,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         self.email_attempts = []
 
         (user_id, tok) = self.create_user()
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST",
             "/_matrix/client/unstable/account_validity/send_mail",
             access_token=tok,
@@ -555,10 +550,10 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
                 "erase": False,
             }
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "account/deactivate", request_data, access_token=tok
         )
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
 
         self.reactor.advance(datetime.timedelta(days=8).total_seconds())
 
@@ -606,7 +601,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         self.email_attempts = []
 
         # Test that we're still able to manually trigger a mail to be sent.
-        request, channel = self.make_request(
+        channel = self.make_request(
             b"POST",
             "/_matrix/client/unstable/account_validity/send_mail",
             access_token=tok,
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 6cd4eb6624..bd574077e7 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -60,7 +60,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         event_id = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, event_id),
             access_token=self.user_token,
@@ -107,7 +107,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         annotation_id = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
             % (self.room, self.parent_id),
@@ -152,7 +152,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             if prev_token:
                 from_token = "&from=" + prev_token
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
                 % (self.room, self.parent_id, from_token),
@@ -210,7 +210,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             if prev_token:
                 from_token = "&from=" + prev_token
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
                 % (self.room, self.parent_id, from_token),
@@ -279,7 +279,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             if prev_token:
                 from_token = "&from=" + prev_token
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "/_matrix/client/unstable/rooms/%s"
                 "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
@@ -325,7 +325,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
         self.assertEquals(200, channel.code, channel.json_body)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/aggregations/%s"
             % (self.room, self.parent_id),
@@ -357,7 +357,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Now lets redact one of the 'a' reactions
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
             access_token=self.user_token,
@@ -365,7 +365,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(200, channel.code, channel.json_body)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/aggregations/%s"
             % (self.room, self.parent_id),
@@ -382,7 +382,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         """Test that aggregations must be annotations.
         """
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
             % (self.room, self.parent_id, RelationTypes.REPLACE),
@@ -414,7 +414,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         reply_2 = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
             access_token=self.user_token,
@@ -450,7 +450,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         edit_event_id = channel.json_body["event_id"]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
             access_token=self.user_token,
@@ -507,7 +507,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(200, channel.code, channel.json_body)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
             access_token=self.user_token,
@@ -549,7 +549,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Check the relation is returned
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
             % (self.room, original_event_id),
@@ -561,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(len(channel.json_body["chunk"]), 1)
 
         # Redact the original event
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/redact/%s/%s"
             % (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
@@ -571,7 +571,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Try to check for remaining m.replace relations
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
             % (self.room, original_event_id),
@@ -598,7 +598,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Redact the original
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/rooms/%s/redact/%s/%s"
             % (
@@ -612,7 +612,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
 
         # Check that aggregations returns zero
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
             % (self.room, original_event_id),
@@ -656,7 +656,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         original_id = parent_id if parent_id else self.parent_id
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
             % (self.room, original_id, relation_type, event_type, query),
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 562a9c1ba4..116ace1812 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -17,6 +17,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import shared_rooms
 
 from tests import unittest
+from tests.server import FakeChannel
 
 
 class UserSharedRoomsTest(unittest.HomeserverTestCase):
@@ -40,14 +41,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.store = hs.get_datastore()
         self.handler = hs.get_user_directory_handler()
 
-    def _get_shared_rooms(self, token, other_user):
-        request, channel = self.make_request(
+    def _get_shared_rooms(self, token, other_user) -> FakeChannel:
+        return self.make_request(
             "GET",
             "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
             % other_user,
             access_token=token,
         )
-        return request, channel
 
     def test_shared_room_list_public(self):
         """
@@ -63,7 +63,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
         self.helper.join(room, user=u2, tok=u2_token)
 
-        request, channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 1)
         self.assertEquals(channel.json_body["joined"][0], room)
@@ -82,7 +82,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
         self.helper.join(room, user=u2, tok=u2_token)
 
-        request, channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 1)
         self.assertEquals(channel.json_body["joined"][0], room)
@@ -104,7 +104,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.join(room_public, user=u2, tok=u2_token)
         self.helper.join(room_private, user=u1, tok=u1_token)
 
-        request, channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 2)
         self.assertTrue(room_public in channel.json_body["joined"])
@@ -125,13 +125,13 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2, tok=u2_token)
 
         # Assert user directory is not empty
-        request, channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 1)
         self.assertEquals(channel.json_body["joined"][0], room)
 
         self.helper.leave(room, user=u1, tok=u1_token)
 
-        request, channel = self._get_shared_rooms(u2_token, u1)
+        channel = self._get_shared_rooms(u2_token, u1)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 31ac0fccb8..512e36c236 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -35,7 +35,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
     ]
 
     def test_sync_argless(self):
-        request, channel = self.make_request("GET", "/sync")
+        channel = self.make_request("GET", "/sync")
 
         self.assertEqual(channel.code, 200)
         self.assertTrue(
@@ -55,7 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         """
         self.hs.config.use_presence = False
 
-        request, channel = self.make_request("GET", "/sync")
+        channel = self.make_request("GET", "/sync")
 
         self.assertEqual(channel.code, 200)
         self.assertTrue(
@@ -194,7 +194,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             tok=tok,
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/sync?filter=%s" % sync_filter, access_token=tok
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -245,21 +245,19 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         self.helper.send(room, body="There!", tok=other_access_token)
 
         # Start typing.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             typing_url % (room, other_user_id, other_access_token),
             b'{"typing": true, "timeout": 30000}',
         )
         self.assertEquals(200, channel.code)
 
-        request, channel = self.make_request(
-            "GET", "/sync?access_token=%s" % (access_token,)
-        )
+        channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,))
         self.assertEquals(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
         # Stop typing.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             typing_url % (room, other_user_id, other_access_token),
             b'{"typing": false}',
@@ -267,7 +265,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code)
 
         # Start typing.
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             typing_url % (room, other_user_id, other_access_token),
             b'{"typing": true, "timeout": 30000}',
@@ -275,9 +273,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code)
 
         # Should return immediately
-        request, channel = self.make_request(
-            "GET", sync_url % (access_token, next_batch)
-        )
+        channel = self.make_request("GET", sync_url % (access_token, next_batch))
         self.assertEquals(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
@@ -289,9 +285,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         # invalidate the stream token.
         self.helper.send(room, body="There!", tok=other_access_token)
 
-        request, channel = self.make_request(
-            "GET", sync_url % (access_token, next_batch)
-        )
+        channel = self.make_request("GET", sync_url % (access_token, next_batch))
         self.assertEquals(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
@@ -299,9 +293,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         # ahead, and therefore it's saying the typing (that we've actually
         # already seen) is new, since it's got a token above our new, now-reset
         # stream token.
-        request, channel = self.make_request(
-            "GET", sync_url % (access_token, next_batch)
-        )
+        channel = self.make_request("GET", sync_url % (access_token, next_batch))
         self.assertEquals(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
@@ -383,7 +375,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
 
         # Send a read receipt to tell the server we've read the latest event.
         body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/rooms/%s/read_markers" % self.room_id,
             body,
@@ -450,7 +442,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
     def _check_unread_count(self, expected_count: True):
         """Syncs and compares the unread count with the expected value."""
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", self.url % self.next_batch, access_token=self.tok,
         )
 
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index fbcf8d5b86..5e90d656f7 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -39,7 +39,7 @@ from tests.utils import default_config
 class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock()
-        return self.setup_test_homeserver(http_client=self.http_client)
+        return self.setup_test_homeserver(federation_http_client=self.http_client)
 
     def create_test_resource(self):
         return create_resource_tree(
@@ -172,7 +172,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             }
         ]
         self.hs2 = self.setup_test_homeserver(
-            http_client=self.http_client2, config=config
+            federation_http_client=self.http_client2, config=config
         )
 
         # wire up outbound POST /key/v2/query requests from hs2 so that they
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 2a3b2a8f27..a6c6985173 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         config = self.default_config()
         config["media_store_path"] = self.media_store_path
-        config["thumbnail_requirements"] = {}
         config["max_image_pixels"] = 2000000
 
         provider_config = {
@@ -214,7 +213,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         }
         config["media_storage_providers"] = [provider_config]
 
-        hs = self.setup_test_homeserver(config=config, http_client=client)
+        hs = self.setup_test_homeserver(config=config, federation_http_client=client)
 
         return hs
 
@@ -228,7 +227,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
     def _req(self, content_disposition):
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.download_resource),
             "GET",
@@ -313,18 +312,42 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
 
     def test_thumbnail_crop(self):
+        """Test that a cropped remote thumbnail is available."""
         self._test_thumbnail(
             "crop", self.test_image.expected_cropped, self.test_image.expected_found
         )
 
     def test_thumbnail_scale(self):
+        """Test that a scaled remote thumbnail is available."""
         self._test_thumbnail(
             "scale", self.test_image.expected_scaled, self.test_image.expected_found
         )
 
+    def test_invalid_type(self):
+        """An invalid thumbnail type is never available."""
+        self._test_thumbnail("invalid", None, False)
+
+    @unittest.override_config(
+        {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+    )
+    def test_no_thumbnail_crop(self):
+        """
+        Override the config to generate only scaled thumbnails, but request a cropped one.
+        """
+        self._test_thumbnail("crop", None, False)
+
+    @unittest.override_config(
+        {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+    )
+    def test_no_thumbnail_scale(self):
+        """
+        Override the config to generate only cropped thumbnails, but request a scaled one.
+        """
+        self._test_thumbnail("scale", None, False)
+
     def _test_thumbnail(self, method, expected_body, expected_found):
         params = "?width=32&height=32&method=" + method
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(self.thumbnail_resource),
             "GET",
@@ -362,3 +385,16 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                     "error": "Not found [b'example.com', b'12345']",
                 },
             )
+
+    def test_x_robots_tag_header(self):
+        """
+        Tests that the `X-Robots-Tag` header is present, which informs web crawlers
+        to not index, archive, or follow links in media.
+        """
+        channel = self._req(b"inline; filename=out" + self.test_image.extension)
+
+        headers = channel.headers
+        self.assertEqual(
+            headers.getRawHeaders(b"X-Robots-Tag"),
+            [b"noindex, nofollow, noarchive, noimageindex"],
+        )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index ccdc8c2ecf..6968502433 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,42 +18,23 @@ import re
 
 from mock import patch
 
-import attr
-
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone
 
 from tests import unittest
 from tests.server import FakeTransport
 
-
-@attr.s
-class FakeResponse:
-    version = attr.ib()
-    code = attr.ib()
-    phrase = attr.ib()
-    headers = attr.ib()
-    body = attr.ib()
-    absoluteURI = attr.ib()
-
-    @property
-    def request(self):
-        @attr.s
-        class FakeTransport:
-            absoluteURI = self.absoluteURI
-
-        return FakeTransport()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
+try:
+    import lxml
+except ImportError:
+    lxml = None
 
 
 class URLPreviewTests(unittest.HomeserverTestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
 
     hijack_auth = True
     user_id = "@test:user"
@@ -139,7 +120,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
     def test_cache_returns_correct_type(self):
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://matrix.org",
             shorthand=False,
@@ -164,7 +145,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         )
 
         # Check the cache returns the correct response
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://matrix.org", shorthand=False
         )
 
@@ -180,7 +161,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertNotIn("http://matrix.org", self.preview_url._cache)
 
         # Check the database cache returns the correct response
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://matrix.org", shorthand=False
         )
 
@@ -201,7 +182,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             b"</head></html>"
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://matrix.org",
             shorthand=False,
@@ -236,7 +217,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             b"</head></html>"
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://matrix.org",
             shorthand=False,
@@ -271,7 +252,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             b"</head></html>"
         )
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://matrix.org",
             shorthand=False,
@@ -304,7 +285,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://example.com",
             shorthand=False,
@@ -334,7 +315,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
 
@@ -355,7 +336,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
 
@@ -372,7 +353,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         Blacklisted IP addresses, accessed directly, are not spidered.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://192.168.1.1", shorthand=False
         )
 
@@ -391,7 +372,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         Blacklisted IP ranges, accessed directly, are not spidered.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://1.1.1.2", shorthand=False
         )
 
@@ -411,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://example.com",
             shorthand=False,
@@ -448,7 +429,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             (IPv4Address, "10.1.2.3"),
         ]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
         self.assertEqual(channel.code, 502)
@@ -468,7 +449,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
         ]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
 
@@ -489,7 +470,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "preview_url?url=http://example.com", shorthand=False
         )
 
@@ -506,7 +487,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         OPTIONS returns the OPTIONS.
         """
-        request, channel = self.make_request(
+        channel = self.make_request(
             "OPTIONS", "preview_url?url=http://example.com", shorthand=False
         )
         self.assertEqual(channel.code, 200)
@@ -519,7 +500,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
 
         # Build and make a request to the server
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET",
             "preview_url?url=http://example.com",
             shorthand=False,
@@ -593,7 +574,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 b"</head></html>"
             )
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
                 shorthand=False,
@@ -658,7 +639,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             }
             end_content = json.dumps(result).encode("utf-8")
 
-            request, channel = self.make_request(
+            channel = self.make_request(
                 "GET",
                 "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
                 shorthand=False,
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 02a46e5fda..32acd93dc1 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -25,7 +25,7 @@ class HealthCheckTests(unittest.HomeserverTestCase):
         return HealthResource()
 
     def test_health(self):
-        request, channel = self.make_request("GET", "/health", shorthand=False)
+        channel = self.make_request("GET", "/health", shorthand=False)
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 6a930f4148..14de0921be 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -28,11 +28,11 @@ class WellKnownTests(unittest.HomeserverTestCase):
         self.hs.config.public_baseurl = "https://tesths"
         self.hs.config.default_identity_server = "https://testis"
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/.well-known/matrix/client", shorthand=False
         )
 
-        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.code, 200)
         self.assertEqual(
             channel.json_body,
             {
@@ -44,8 +44,8 @@ class WellKnownTests(unittest.HomeserverTestCase):
     def test_well_known_no_public_baseurl(self):
         self.hs.config.public_baseurl = None
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/.well-known/matrix/client", shorthand=False
         )
 
-        self.assertEqual(request.code, 404)
+        self.assertEqual(channel.code, 404)
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..6419c445ec 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -47,13 +47,26 @@ class FakeChannel:
     site = attr.ib(type=Site)
     _reactor = attr.ib()
     result = attr.ib(type=dict, default=attr.Factory(dict))
+    _ip = attr.ib(type=str, default="127.0.0.1")
     _producer = None
 
     @property
     def json_body(self):
-        if not self.result:
-            raise Exception("No result yet.")
-        return json.loads(self.result["body"].decode("utf8"))
+        return json.loads(self.text_body)
+
+    @property
+    def text_body(self) -> str:
+        """The body of the result, utf-8-decoded.
+
+        Raises an exception if the request has not yet completed.
+        """
+        if not self.is_finished:
+            raise Exception("Request not yet completed")
+        return self.result["body"].decode("utf8")
+
+    def is_finished(self) -> bool:
+        """check if the response has been completely received"""
+        return self.result.get("done", False)
 
     @property
     def code(self):
@@ -62,7 +75,7 @@ class FakeChannel:
         return int(self.result["code"])
 
     @property
-    def headers(self):
+    def headers(self) -> Headers:
         if not self.result:
             raise Exception("No result yet.")
         h = Headers()
@@ -108,7 +121,7 @@ class FakeChannel:
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,
         # causing us to record the MAU
-        return address.IPv4Address("TCP", "127.0.0.1", 3423)
+        return address.IPv4Address("TCP", self._ip, 3423)
 
     def getHost(self):
         return None
@@ -124,7 +137,7 @@ class FakeChannel:
         self._reactor.run()
         x = 0
 
-        while not self.result.get("done"):
+        while not self.is_finished():
             # If there's a producer, tell it to resume producing so we get content
             if self._producer:
                 self._producer.resumeProducing()
@@ -136,6 +149,16 @@ class FakeChannel:
 
             self._reactor.advance(0.1)
 
+    def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
+        """Process the contents of any Set-Cookie headers in the response
+
+        Any cookines found are added to the given dict
+        """
+        for h in self.headers.getRawHeaders("Set-Cookie"):
+            parts = h.split(";")
+            k, v = parts[0].split("=", maxsplit=1)
+            cookies[k] = v
+
 
 class FakeSite:
     """
@@ -174,11 +197,12 @@ def make_request(
     custom_headers: Optional[
         Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
     ] = None,
-):
+    client_ip: str = "127.0.0.1",
+) -> FakeChannel:
     """
     Make a web request using the given method, path and content, and render it
 
-    Returns the Request and the Channel underneath.
+    Returns the fake Channel object which records the response to the request.
 
     Args:
         site: The twisted Site to use to render the request
@@ -201,8 +225,11 @@ def make_request(
              will pump the reactor until the the renderer tells the channel the request
              is finished.
 
+        client_ip: The IP to use as the requesting IP. Useful for testing
+            ratelimiting.
+
     Returns:
-        Tuple[synapse.http.site.SynapseRequest, channel]
+        channel
     """
     if not isinstance(method, bytes):
         method = method.encode("ascii")
@@ -216,8 +243,9 @@ def make_request(
         and not path.startswith(b"/_matrix")
         and not path.startswith(b"/_synapse")
     ):
+        if path.startswith(b"/"):
+            path = path[1:]
         path = b"/_matrix/client/r0/" + path
-        path = path.replace(b"//", b"/")
 
     if not path.startswith(b"/"):
         path = b"/" + path
@@ -227,7 +255,7 @@ def make_request(
     if isinstance(content, str):
         content = content.encode("utf8")
 
-    channel = FakeChannel(site, reactor)
+    channel = FakeChannel(site, reactor, ip=client_ip)
 
     req = request(channel)
     req.content = BytesIO(content)
@@ -258,12 +286,13 @@ def make_request(
         for k, v in custom_headers:
             req.requestHeaders.addRawHeader(k, v)
 
+    req.parseCookies()
     req.requestReceived(method, path, b"1.1")
 
     if await_result:
         channel.await_result()
 
-    return req, channel
+    return channel
 
 
 @implementer(IReactorPluggableNameResolver)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index e0a9cd93ac..4dd5a36178 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -70,7 +70,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
         the notice URL + an authentication code.
         """
         # Initial sync, to get the user consent room invite
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/sync", access_token=self.access_token
         )
         self.assertEqual(channel.code, 200)
@@ -79,7 +79,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
         room_id = list(channel.json_body["rooms"]["invite"].keys())[0]
 
         # Join the room
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST",
             "/_matrix/client/r0/rooms/" + room_id + "/join",
             access_token=self.access_token,
@@ -87,7 +87,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
 
         # Sync again, to get the message in the room
-        request, channel = self.make_request(
+        channel = self.make_request(
             "GET", "/_matrix/client/r0/sync", access_token=self.access_token
         )
         self.assertEqual(channel.code, 200)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 9c8027a5b2..fea54464af 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -305,7 +305,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         self.register_user("user", "password")
         tok = self.login("user", "password")
 
-        request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+        channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
 
         invites = channel.json_body["rooms"]["invite"]
         self.assertEqual(len(invites), 0, invites)
@@ -318,7 +318,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         # Sync again to retrieve the events in the room, so we can check whether this
         # room has a notice in it.
-        request, channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
+        channel = self.make_request("GET", "/sync?timeout=0", access_token=tok)
 
         # Scan the events in the room to search for a message from the server notices
         # user.
@@ -353,9 +353,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
             tok = self.login(localpart, "password")
 
             # Sync with the user's token to mark the user as active.
-            request, channel = self.make_request(
-                "GET", "/sync?timeout=0", access_token=tok,
-            )
+            channel = self.make_request("GET", "/sync?timeout=0", access_token=tok,)
 
             # Also retrieves the list of invites for this user. We don't care about that
             # one except if we're processing the last user, which should have received an
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..77c72834f2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
 from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+    _get_auth_chain_difference,
+    lexicographical_topological_sort,
+    resolve_events_with_store,
+)
 from synapse.types import EventID
 
 from tests import unittest
@@ -84,7 +88,7 @@ class FakeEvent:
         event_dict = {
             "auth_events": [(a, {}) for a in auth_events],
             "prev_events": [(p, {}) for p in prev_events],
-            "event_id": self.node_id,
+            "event_id": self.event_id,
             "sender": self.sender,
             "type": self.type,
             "content": self.content,
@@ -377,6 +381,61 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
+    def test_mainline_sort(self):
+        """Tests that the mainline ordering works correctly.
+        """
+
+        events = [
+            FakeEvent(
+                id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+            ),
+            FakeEvent(
+                id="PA1",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={"users": {ALICE: 100, BOB: 50}},
+            ),
+            FakeEvent(
+                id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+            ),
+            FakeEvent(
+                id="PA2",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={
+                    "users": {ALICE: 100, BOB: 50},
+                    "events": {EventTypes.PowerLevels: 100},
+                },
+            ),
+            FakeEvent(
+                id="PB",
+                sender=BOB,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={"users": {ALICE: 100, BOB: 50}},
+            ),
+            FakeEvent(
+                id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
+            ),
+            FakeEvent(
+                id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+            ),
+        ]
+
+        edges = [
+            ["END", "T3", "PA2", "T2", "PA1", "T1", "START"],
+            ["END", "T4", "PB", "PA1"],
+        ]
+
+        # We expect T3 to be picked as the other topics are pointing at older
+        # power levels. Note that without mainline ordering we'd pick T4 due to
+        # it being sent *after* T3.
+        expected_state_ids = ["T3", "PA2"]
+
+        self.do_check(events, edges, expected_state_ids)
+
     def do_check(self, events, edges, expected_state_ids):
         """Take a list of events and edges and calculate the state of the
         graph at END, and asserts it matches `expected_state_ids`
@@ -587,6 +646,134 @@ class SimpleParamStateTestCase(unittest.TestCase):
         self.assert_dict(self.expected_combined_state, state)
 
 
+class AuthChainDifferenceTestCase(unittest.TestCase):
+    """We test that `_get_auth_chain_difference` correctly handles unpersisted
+    events.
+    """
+
+    def test_simple(self):
+        # Test getting the auth difference for a simple chain with a single
+        # unpersisted event:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #           C -|-> B -> A
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c}
+
+        state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {c.event_id})
+
+    def test_multiple_unpersisted_chain(self):
+        # Test getting the auth difference for a simple chain with multiple
+        # unpersisted events:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #      D -> C -|-> B -> A
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        d = FakeEvent(
+            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c, d.event_id: d}
+
+        state_sets = [
+            {"a": a.event_id, "b": b.event_id},
+            {"c": c.event_id, "d": d.event_id},
+        ]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {d.event_id, c.event_id})
+
+    def test_unpersisted_events_different_sets(self):
+        # Test getting the auth difference for with multiple unpersisted events
+        # in different branches:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #     D --> C -|-> B -> A
+        #     E ----^ -|---^
+        #              |
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        d = FakeEvent(
+            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id], [])
+
+        e = FakeEvent(
+            id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id, b.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+        state_sets = [
+            {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+            {"c": c.event_id, "d": d.event_id},
+        ]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {d.event_id, e.event_id})
+
+
 def pairwise(iterable):
     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
     a, b = itertools.tee(iterable)
@@ -647,7 +834,7 @@ class TestStateResolutionStore:
 
         return list(result)
 
-    def get_auth_chain_difference(self, auth_sets):
+    def get_auth_chain_difference(self, room_id, auth_sets):
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
new file mode 100644
index 0000000000..673e1fe3e3
--- /dev/null
+++ b/tests/storage/test_account_data.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Set
+
+from synapse.api.constants import AccountDataTypes
+
+from tests import unittest
+
+
+class IgnoredUsersTestCase(unittest.HomeserverTestCase):
+    def prepare(self, hs, reactor, clock):
+        self.store = self.hs.get_datastore()
+        self.user = "@user:test"
+
+    def _update_ignore_list(
+        self, *ignored_user_ids: Iterable[str], ignorer_user_id: str = None
+    ) -> None:
+        """Update the account data to block the given users."""
+        if ignorer_user_id is None:
+            ignorer_user_id = self.user
+
+        self.get_success(
+            self.store.add_account_data_for_user(
+                ignorer_user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {u: {} for u in ignored_user_ids}},
+            )
+        )
+
+    def assert_ignorers(
+        self, ignored_user_id: str, expected_ignorer_user_ids: Set[str]
+    ) -> None:
+        self.assertEqual(
+            self.get_success(self.store.ignored_by(ignored_user_id)),
+            expected_ignorer_user_ids,
+        )
+
+    def test_ignoring_users(self):
+        """Basic adding/removing of users from the ignore list."""
+        self._update_ignore_list("@other:test", "@another:remote")
+
+        # Check a user which no one ignores.
+        self.assert_ignorers("@user:test", set())
+
+        # Check a local user which is ignored.
+        self.assert_ignorers("@other:test", {self.user})
+
+        # Check a remote user which is ignored.
+        self.assert_ignorers("@another:remote", {self.user})
+
+        # Add one user, remove one user, and leave one user.
+        self._update_ignore_list("@foo:test", "@another:remote")
+
+        # Check the removed user.
+        self.assert_ignorers("@other:test", set())
+
+        # Check the added user.
+        self.assert_ignorers("@foo:test", {self.user})
+
+        # Check the removed user.
+        self.assert_ignorers("@another:remote", {self.user})
+
+    def test_caching(self):
+        """Ensure that caching works properly between different users."""
+        # The first user ignores a user.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # The second user ignores them.
+        self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+        self.assert_ignorers("@other:test", {self.user, "@second:test"})
+
+        # The first user un-ignores them.
+        self._update_ignore_list()
+        self.assert_ignorers("@other:test", {"@second:test"})
+
+    def test_invalid_data(self):
+        """Invalid data ends up clearing out the ignored users list."""
+        # Add some data and ensure it is there.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # No ignored_users key.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user, AccountDataTypes.IGNORED_USER_LIST, {},
+            )
+        )
+
+        # No one ignores the user now.
+        self.assert_ignorers("@other:test", set())
+
+        # Add some data and ensure it is there.
+        self._update_ignore_list("@other:test")
+        self.assert_ignorers("@other:test", {self.user})
+
+        # Invalid data.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": "unexpected"},
+            )
+        )
+
+        # No one ignores the user now.
+        self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ecb00f4e02..dabc1c5f09 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -80,6 +80,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_count_devices_by_users(self):
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device1", "display_name 1")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device2", "display_name 2")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id2", "device3", "display_name 3")
+        )
+
+        res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+        self.assertEqual(0, res)
+
+        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+        self.assertEqual(0, res)
+
+        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+        self.assertEqual(2, res)
+
+        res = yield defer.ensureDeferred(
+            self.store.count_devices_by_users(["user_id", "user_id2"])
+        )
+        self.assertEqual(3, res)
+
+    @defer.inlineCallbacks
     def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 35dafbb904..3d7760d5d9 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -26,7 +26,7 @@ room_key = {
 
 class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.store = hs.get_datastore()
         return hs
 
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
new file mode 100644
index 0000000000..0c46ad595b
--- /dev/null
+++ b/tests/storage/test_event_chain.py
@@ -0,0 +1,741 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Set, Tuple
+
+from twisted.trial import unittest
+
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.storage.databases.main.events import _LinkMap
+from synapse.types import create_requester
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventChainStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self._next_stream_ordering = 1
+
+    def test_simple(self):
+        """Test that the example in `docs/auth_chain_difference_algorithm.md`
+        works.
+        """
+
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        power_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        bob_join_2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join_2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[
+                    create.event_id,
+                    alice_join.event_id,
+                    power_2.event_id,
+                ],
+            )
+        )
+
+        events = [
+            create,
+            bob_join,
+            power,
+            alice_invite,
+            alice_join,
+            bob_join_2,
+            power_2,
+            alice_join2,
+        ]
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+            (bob_join_2, power),
+            (alice_join2, power_2),
+        ]
+
+        self.persist(events)
+        chain_map, link_map = self.fetch_chains(events)
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+        # Test that everything can reach the create event, but the create event
+        # can't reach anything.
+        for event in events[1:]:
+            self.assertTrue(
+                link_map.exists_path_from(
+                    chain_map[event.event_id], chain_map[create.event_id]
+                ),
+            )
+
+            self.assertFalse(
+                link_map.exists_path_from(
+                    chain_map[create.event_id], chain_map[event.event_id],
+                ),
+            )
+
+    def test_out_of_order_events(self):
+        """Test that we handle persisting events that we don't have the full
+        auth chain for yet (which should only happen for out of band memberships).
+        """
+        event_factory = self.hs.get_event_builder_factory()
+        bob = "@creator:test"
+        alice = "@alice:test"
+        room_id = "!room:test"
+
+        # Ensure that we have a rooms entry so that we generate the chain index.
+        self.get_success(
+            self.store.store_room(
+                room_id=room_id,
+                room_creator_user_id="",
+                is_public=True,
+                room_version=RoomVersions.V6,
+            )
+        )
+
+        # First persist the base room.
+        create = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Create,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "create"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[])
+        )
+
+        bob_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": bob,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "bob_join"},
+                },
+            ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+        )
+
+        power = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.PowerLevels,
+                    "state_key": "",
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "power"},
+                },
+            ).build(
+                prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+            )
+        )
+
+        self.persist([create, bob_join, power])
+
+        # Now persist an invite and a couple of memberships out of order.
+        alice_invite = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": bob,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_invite"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+            )
+        )
+
+        alice_join = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+            )
+        )
+
+        alice_join2 = self.get_success(
+            event_factory.for_room_version(
+                RoomVersions.V6,
+                {
+                    "type": EventTypes.Member,
+                    "state_key": alice,
+                    "sender": alice,
+                    "room_id": room_id,
+                    "content": {"tag": "alice_join2"},
+                },
+            ).build(
+                prev_event_ids=[],
+                auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
+            )
+        )
+
+        self.persist([alice_join])
+        self.persist([alice_join2])
+        self.persist([alice_invite])
+
+        # The end result should be sane.
+        events = [create, bob_join, power, alice_invite, alice_join]
+
+        chain_map, link_map = self.fetch_chains(events)
+
+        expected_links = [
+            (bob_join, create),
+            (power, create),
+            (power, bob_join),
+            (alice_invite, create),
+            (alice_invite, power),
+            (alice_invite, bob_join),
+        ]
+
+        # Check that the expected links and only the expected links have been
+        # added.
+        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+        for start, end in expected_links:
+            start_id, start_seq = chain_map[start.event_id]
+            end_id, end_seq = chain_map[end.event_id]
+
+            self.assertIn(
+                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+            )
+
+    def persist(
+        self, events: List[EventBase],
+    ):
+        """Persist the given events and check that the links generated match
+        those given.
+        """
+
+        persist_events_store = self.hs.get_datastores().persist_events
+
+        for e in events:
+            e.internal_metadata.stream_ordering = self._next_stream_ordering
+            self._next_stream_ordering += 1
+
+        def _persist(txn):
+            # We need to persist the events to the events and state_events
+            # tables.
+            persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+
+            # Actually call the function that calculates the auth chain stuff.
+            persist_events_store._persist_event_auth_chain_txn(txn, events)
+
+        self.get_success(
+            persist_events_store.db_pool.runInteraction("_persist", _persist,)
+        )
+
+    def fetch_chains(
+        self, events: List[EventBase]
+    ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
+
+        # Fetch the map from event ID -> (chain ID, sequence number)
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chains",
+                column="event_id",
+                iterable=[e.event_id for e in events],
+                retcols=("event_id", "chain_id", "sequence_number"),
+                keyvalues={},
+            )
+        )
+
+        chain_map = {
+            row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+        }
+
+        # Fetch all the links and pass them to the _LinkMap.
+        rows = self.get_success(
+            self.store.db_pool.simple_select_many_batch(
+                table="event_auth_chain_links",
+                column="origin_chain_id",
+                iterable=[chain_id for chain_id, _ in chain_map.values()],
+                retcols=(
+                    "origin_chain_id",
+                    "origin_sequence_number",
+                    "target_chain_id",
+                    "target_sequence_number",
+                ),
+                keyvalues={},
+            )
+        )
+
+        link_map = _LinkMap()
+        for row in rows:
+            added = link_map.add_link(
+                (row["origin_chain_id"], row["origin_sequence_number"]),
+                (row["target_chain_id"], row["target_sequence_number"]),
+            )
+
+            # We shouldn't have persisted any redundant links
+            self.assertTrue(added)
+
+        return chain_map, link_map
+
+
+class LinkMapTestCase(unittest.TestCase):
+    def test_simple(self):
+        """Basic tests for the LinkMap.
+        """
+        link_map = _LinkMap()
+
+        link_map.add_link((1, 1), (2, 1), new=False)
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+        self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
+        self.assertCountEqual(link_map.get_additions(), [])
+        self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
+        self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
+        self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
+
+        # Attempting to add a redundant link is ignored.
+        self.assertFalse(link_map.add_link((1, 4), (2, 1)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+
+        # Adding new non-redundant links works
+        self.assertTrue(link_map.add_link((1, 3), (2, 3)))
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertTrue(link_map.add_link((2, 5), (1, 3)))
+        self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+        self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+        self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+
+
+class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.user_id = self.register_user("foo", "pass")
+        self.token = self.login("foo", "pass")
+        self.requester = create_requester(self.user_id)
+
+    def _generate_room(self) -> Tuple[str, List[Set[str]]]:
+        """Insert a room without a chain cover index.
+        """
+        room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+        # Mark the room as not having a chain cover index
+        self.get_success(
+            self.store.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+                desc="test",
+            )
+        )
+
+        # Create a fork in the DAG with different events.
+        event_handler = self.hs.get_event_creation_handler()
+        latest_event_ids = self.get_success(
+            self.store.get_prev_events_for_room(room_id)
+        )
+        event, context = self.get_success(
+            event_handler.create_event(
+                self.requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": self.user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(self.requester, event, context)
+        )
+        state1 = set(self.get_success(context.get_current_state_ids()).values())
+
+        event, context = self.get_success(
+            event_handler.create_event(
+                self.requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": self.user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(self.requester, event, context)
+        )
+        state2 = set(self.get_success(context.get_current_state_ids()).values())
+
+        # Delete the chain cover info.
+
+        def _delete_tables(txn):
+            txn.execute("DELETE FROM event_auth_chains")
+            txn.execute("DELETE FROM event_auth_chain_links")
+
+        self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
+
+        return room_id, [state1, state2]
+
+    def test_background_update_single_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                states,
+            )
+        )
+
+    def test_background_update_multiple_rooms(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+        # Create a room
+        room_id1, states1 = self._generate_room()
+        room_id2, states2 = self._generate_room()
+        room_id3, states2 = self._generate_room()
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id1,
+                states1,
+            )
+        )
+
+    def test_background_update_single_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                states,
+            )
+        )
+
+    def test_background_update_multiple_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create the rooms
+        room_id1, _ = self._generate_room()
+        room_id2, _ = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id1, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id2, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..9d04a066d8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
 import tests.unittest
 import tests.utils
 
@@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])
 
-    def test_auth_difference(self):
+    @parameterized.expand([(True,), (False,)])
+    def test_auth_difference(self, use_chain_cover_index: bool):
         room_id = "@ROOM:local"
 
         # The silly auth graph we use to test the auth difference algorithm,
@@ -159,77 +165,279 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             "j": 1,
         }
 
+        # Mark the room as not having a cover index
+
+        def store_room(txn):
+            self.store.db_pool.simple_insert_txn(
+                txn,
+                "rooms",
+                {
+                    "room_id": room_id,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": use_chain_cover_index,
+                },
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
         # We rudely fiddle with the appropriate tables directly, as that's much
         # easier than constructing events properly.
 
-        def insert_event(txn, event_id, stream_ordering):
+        def insert_event(txn):
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn,
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                ],
+            )
+
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
+        # Now actually test that various combinations give the right result:
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
+        self.assertSetEqual(difference, set())
+
+    def test_auth_difference_partial_cover(self):
+        """Test that we correctly handle rooms where not all events have a chain
+        cover calculated. This can happen in some obscure edge cases, including
+        during the background update that calculates the chain cover for old
+        rooms.
+        """
+
+        room_id = "@ROOM:local"
+
+        # The silly auth graph we use to test the auth difference algorithm,
+        # where the top are the most recent events.
+        #
+        #   A   B
+        #    \ /
+        #  D  E
+        #  \  |
+        #   ` F   C
+        #     |  /|
+        #     G ´ |
+        #     | \ |
+        #     H   I
+        #     |   |
+        #     K   J
+
+        auth_graph = {
+            "a": ["e"],
+            "b": ["e"],
+            "c": ["g", "i"],
+            "d": ["f"],
+            "e": ["f"],
+            "f": ["g"],
+            "g": ["h", "i"],
+            "h": ["k"],
+            "i": ["j"],
+            "k": [],
+            "j": [],
+        }
 
-            depth = depth_map[event_id]
+        depth_map = {
+            "a": 7,
+            "b": 7,
+            "c": 4,
+            "d": 6,
+            "e": 6,
+            "f": 5,
+            "g": 3,
+            "h": 2,
+            "i": 2,
+            "k": 1,
+            "j": 1,
+        }
+
+        # We rudely fiddle with the appropriate tables directly, as that's much
+        # easier than constructing events properly.
 
+        def insert_event(txn):
+            # First insert the room and mark it as having a chain cover.
             self.store.db_pool.simple_insert_txn(
                 txn,
-                table="events",
-                values={
-                    "event_id": event_id,
+                "rooms",
+                {
                     "room_id": room_id,
-                    "depth": depth,
-                    "topological_ordering": depth,
-                    "type": "m.test",
-                    "processed": True,
-                    "outlier": False,
-                    "stream_ordering": stream_ordering,
+                    "creator": "room_creator_user_id",
+                    "is_public": True,
+                    "room_version": "6",
+                    "has_auth_chain_index": True,
                 },
             )
 
-            self.store.db_pool.simple_insert_many_txn(
+            stream_ordering = 0
+
+            for event_id in auth_graph:
+                stream_ordering += 1
+                depth = depth_map[event_id]
+
+                self.store.db_pool.simple_insert_txn(
+                    txn,
+                    table="events",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "depth": depth,
+                        "topological_ordering": depth,
+                        "type": "m.test",
+                        "processed": True,
+                        "outlier": False,
+                        "stream_ordering": stream_ordering,
+                    },
+                )
+
+            # Insert all events apart from 'B'
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
                 txn,
-                table="event_auth",
-                values=[
-                    {"event_id": event_id, "room_id": room_id, "auth_id": a}
-                    for a in auth_graph[event_id]
+                [
+                    FakeEvent(event_id, room_id, auth_graph[event_id])
+                    for event_id in auth_graph
+                    if event_id != "b"
                 ],
             )
 
-        next_stream_ordering = 0
-        for event_id in auth_graph:
-            next_stream_ordering += 1
-            self.get_success(
-                self.store.db_pool.runInteraction(
-                    "insert", insert_event, event_id, next_stream_ordering
-                )
+            # Now we insert the event 'B' without a chain cover, by temporarily
+            # pretending the room doesn't have a chain cover.
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+            )
+
+            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+                txn, [FakeEvent("b", room_id, auth_graph["b"])],
+            )
+
+            self.store.db_pool.simple_update_txn(
+                txn,
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": True},
             )
 
+        self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
         )
         self.assertSetEqual(difference, {"a", "b"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
         )
         self.assertSetEqual(difference, {"a", "b", "d", "e"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
         )
         self.assertSetEqual(difference, {"a", "b"})
 
-        difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
         self.assertSetEqual(difference, set())
+
+
+@attr.s
+class FakeEvent:
+    event_id = attr.ib()
+    room_id = attr.ib()
+    auth_events = attr.ib()
+
+    type = "foo"
+    state_key = "foo"
+
+    internal_metadata = _EventInternalMetadata({})
+
+    def auth_event_ids(self):
+        return self.auth_events
+
+    def is_state(self):
+        return True
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
new file mode 100644
index 0000000000..71210ce606
--- /dev/null
+++ b/tests/storage/test_events.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import RoomVersions
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class ExtremPruneTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.state = self.hs.get_state_handler()
+        self.persistence = self.hs.get_storage().persistence
+        self.store = self.hs.get_datastore()
+
+        self.register_user("user", "pass")
+        self.token = self.login("user", "pass")
+
+        self.room_id = self.helper.create_room_as(
+            "user", room_version=RoomVersions.V6.identifier, tok=self.token
+        )
+
+        body = self.helper.send(self.room_id, body="Test", tok=self.token)
+        local_message_event_id = body["event_id"]
+
+        # Fudge a remote event and persist it. This will be the extremity before
+        # the gap.
+        self.remote_event_1 = event_from_pdu_json(
+            {
+                "type": EventTypes.Message,
+                "state_key": "@user:other",
+                "content": {},
+                "room_id": self.room_id,
+                "sender": "@user:other",
+                "depth": 5,
+                "prev_events": [local_message_event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        self.persist_event(self.remote_event_1)
+
+        # Check that the current extremities is the remote event.
+        self.assert_extremities([self.remote_event_1.event_id])
+
+    def persist_event(self, event, state=None):
+        """Persist the event, with optional state
+        """
+        context = self.get_success(
+            self.state.compute_event_context(event, old_state=state)
+        )
+        self.get_success(self.persistence.persist_event(event, context))
+
+    def assert_extremities(self, expected_extremities):
+        """Assert the current extremities for the room
+        """
+        extremities = self.get_success(
+            self.store.get_prev_events_for_room(self.room_id)
+        )
+        self.assertCountEqual(extremities, expected_extremities)
+
+    def test_prune_gap(self):
+        """Test that we drop extremities after a gap when we see an event from
+        the same domain.
+        """
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other",
+                "depth": 50,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([remote_event_2.event_id])
+
+    def test_do_not_prune_gap_if_state_different(self):
+        """Test that we don't prune extremities after a gap if the resolved
+        state is different.
+        """
+
+        # Fudge a second event which points to an event we don't have.
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Message,
+                "state_key": "@user:other",
+                "content": {},
+                "room_id": self.room_id,
+                "sender": "@user:other",
+                "depth": 10,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        # Now we persist it with state with a dropped history visibility
+        # setting. The state resolution across the old and new event will then
+        # include it, and so the resolved state won't match the new state.
+        state_before_gap = dict(
+            self.get_success(self.state.get_current_state(self.room_id))
+        )
+        state_before_gap.pop(("m.room.history_visibility", ""))
+
+        context = self.get_success(
+            self.state.compute_event_context(
+                remote_event_2, old_state=state_before_gap.values()
+            )
+        )
+
+        self.get_success(self.persistence.persist_event(remote_event_2, context))
+
+        # Check that we haven't dropped the old extremity.
+        self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+    def test_prune_gap_if_old(self):
+        """Test that we drop extremities after a gap when the previous extremity
+        is "old"
+        """
+
+        # Advance the clock for many days to make the old extremity "old". We
+        # also set the depth to "lots".
+        self.reactor.advance(7 * 24 * 60 * 60)
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10000,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([remote_event_2.event_id])
+
+    def test_do_not_prune_gap_if_other_server(self):
+        """Test that we do not drop extremities after a gap when we see an event
+        from a different domain.
+        """
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
+
+    def test_prune_gap_if_dummy_remote(self):
+        """Test that we drop extremities after a gap when the previous extremity
+        is a local dummy event and only points to remote events.
+        """
+
+        body = self.helper.send_event(
+            self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+        )
+        local_message_event_id = body["event_id"]
+        self.assert_extremities([local_message_event_id])
+
+        # Advance the clock for many days to make the old extremity "old". We
+        # also set the depth to "lots".
+        self.reactor.advance(7 * 24 * 60 * 60)
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10000,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([remote_event_2.event_id])
+
+    def test_prune_gap_if_dummy_local(self):
+        """Test that we don't drop extremities after a gap when the previous
+        extremity is a local dummy event and points to local events.
+        """
+
+        body = self.helper.send(self.room_id, body="Test", tok=self.token)
+
+        body = self.helper.send_event(
+            self.room_id, type=EventTypes.Dummy, content={}, tok=self.token
+        )
+        local_message_event_id = body["event_id"]
+        self.assert_extremities([local_message_event_id])
+
+        # Advance the clock for many days to make the old extremity "old". We
+        # also set the depth to "lots".
+        self.reactor.advance(7 * 24 * 60 * 60)
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10000,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([remote_event_2.event_id, local_message_event_id])
+
+    def test_do_not_prune_gap_if_not_dummy(self):
+        """Test that we do not drop extremities after a gap when the previous extremity
+        is not a dummy event.
+        """
+
+        body = self.helper.send(self.room_id, body="test", tok=self.token)
+        local_message_event_id = body["event_id"]
+        self.assert_extremities([local_message_event_id])
+
+        # Fudge a second event which points to an event we don't have. This is a
+        # state event so that the state changes (otherwise we won't prune the
+        # extremity as they'll have the same state group).
+        remote_event_2 = event_from_pdu_json(
+            {
+                "type": EventTypes.Member,
+                "state_key": "@user:other2",
+                "content": {"membership": Membership.JOIN},
+                "room_id": self.room_id,
+                "sender": "@user:other2",
+                "depth": 10000,
+                "prev_events": ["$some_unknown_message"],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            RoomVersions.V6,
+        )
+
+        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+
+        self.persist_event(remote_event_2, state=state_before_gap.values())
+
+        # Check the new extremity is just the new remote event.
+        self.assert_extremities([local_message_event_id, remote_event_2.event_id])
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index cc0612cf65..3e2fd4da01 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
             )
@@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.db_pool,
                 stream_name="test_stream",
                 instance_name=instance_name,
-                table="foobar",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("foobar", "instance_name", "stream_id")],
                 sequence_name="foobar_seq",
                 writers=writers,
                 positive=False,
@@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
+
+
+class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db_pool = self.store.db_pool  # type: DatabasePool
+
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar1 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+        txn.execute(
+            """
+            CREATE TABLE foobar2 (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(
+        self, instance_name="master", writers=["master"]
+    ) -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db_pool,
+                stream_name="test_stream",
+                instance_name=instance_name,
+                tables=[
+                    ("foobar1", "instance_name", "stream_id"),
+                    ("foobar2", "instance_name", "stream_id"),
+                ],
+                sequence_name="foobar_seq",
+                writers=writers,
+            )
+
+        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+
+    def _insert_rows(
+        self,
+        table: str,
+        instance_name: str,
+        number: int,
+        update_stream_table: bool = True,
+    ):
+        """Insert N rows as the given instance, inserting with stream IDs pulled
+        from the postgres sequence.
+        """
+
+        def _insert(txn):
+            for _ in range(number):
+                txn.execute(
+                    "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
+                    (instance_name,),
+                )
+                if update_stream_table:
+                    txn.execute(
+                        """
+                        INSERT INTO stream_positions VALUES ('test_stream', ?,  lastval())
+                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+                        """,
+                        (instance_name,),
+                    )
+
+        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+    def test_load_existing_stream(self):
+        """Test creating ID gens with multiple tables that have rows from after
+        the position in `stream_positions` table.
+        """
+        self._insert_rows("foobar1", "first", 3)
+        self._insert_rows("foobar2", "second", 3)
+        self._insert_rows("foobar2", "second", 1, update_stream_table=False)
+
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+
+        # The first ID gen will notice that it can advance its token to 7 as it
+        # has no in progress writes...
+        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
+        self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
+
+        # ... but the second ID gen doesn't know that.
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+        self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 7e7f1286d9..e9e3bca3bf 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -48,3 +48,10 @@ class DataStoreTestCase(unittest.TestCase):
 
         self.assertEquals(1, total)
         self.assertEquals(self.displayname, users.pop()["displayname"])
+
+        users, total = yield defer.ensureDeferred(
+            self.store.get_users_paginate(0, 10, name="BC", guests=False)
+        )
+
+        self.assertEquals(1, total)
+        self.assertEquals(self.displayname, users.pop()["displayname"])
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 3fd0a38cf5..ea63bd56b4 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -48,6 +48,19 @@ class ProfileStoreTestCase(unittest.TestCase):
             ),
         )
 
+        # test set to None
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.u_frank.localpart, None)
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.u_frank.localpart)
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     def test_avatar_url(self):
         yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
@@ -66,3 +79,16 @@ class ProfileStoreTestCase(unittest.TestCase):
                 )
             ),
         )
+
+        # test set to None
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(self.u_frank.localpart, None)
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.u_frank.localpart)
+                )
+            )
+        )
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index cc1f3c53c5..a06ad2c03e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase):
     servlets = [room.register_servlets]
 
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d4f9e809db..a6303bf0ee 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -14,9 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from mock import Mock
-
 from canonicaljson import json
 
 from twisted.internet import defer
@@ -30,12 +27,10 @@ from tests.utils import create_room
 
 
 class RedactionTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
-        config = self.default_config()
+    def default_config(self):
+        config = super().default_config()
         config["redaction_retention_period"] = "30d"
-        return self.setup_test_homeserver(
-            resource_for_federation=Mock(), http_client=None, config=config
-        )
+        return config
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index ff972daeaa..d2aed66f6d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
-
 from synapse.api.constants import Membership
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client.v1 import login, room
@@ -34,12 +32,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver(
-            resource_for_federation=Mock(), http_client=None
-        )
-        return hs
-
     def prepare(self, reactor, clock, hs: TestHomeServer):
 
         # We can't test the RoomMemberStore on its own without the other event
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 738e912468..a6f63f4aaf 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -21,6 +21,8 @@ from tests.utils import setup_test_homeserver
 ALICE = "@alice:a"
 BOB = "@bob:b"
 BOBBY = "@bobby:a"
+# The localpart isn't 'Bela' on purpose so we can test looking up display names.
+BELA = "@somenickname:a"
 
 
 class UserDirectoryStoreTestCase(unittest.TestCase):
@@ -41,6 +43,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
             self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
         )
         yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(BELA, "Bela", None)
+        )
+        yield defer.ensureDeferred(
             self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
         )
 
@@ -72,3 +77,21 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
             )
         finally:
             self.hs.config.user_directory_search_all_users = False
+
+    @defer.inlineCallbacks
+    def test_search_user_dir_stop_words(self):
+        """Tests that a user can look up another user by searching for the start if its
+        display name even if that name happens to be a common English word that would
+        usually be ignored in full text searches.
+        """
+        self.hs.config.user_directory_search_all_users = True
+        try:
+            r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
+            self.assertFalse(r["limited"])
+            self.assertEqual(1, len(r["results"]))
+            self.assertDictEqual(
+                r["results"][0],
+                {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+            )
+        finally:
+            self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 1ce4ea3a01..fc9aab32d0 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -37,7 +37,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.hs_clock = Clock(self.reactor)
         self.homeserver = setup_test_homeserver(
             self.addCleanup,
-            http_client=self.http_client,
+            federation_http_client=self.http_client,
             clock=self.hs_clock,
             reactor=self.reactor,
         )
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )
 
-        with LoggingContext(request="lying_event"):
+        with LoggingContext():
             failure = self.get_failure(
                 self.handler.on_receive_pdu(
                     "test.serv", lying_event, sent_to_us_directly=True
diff --git a/tests/test_mau.py b/tests/test_mau.py
index c5ec6396a7..51660b51d5 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -19,6 +19,7 @@ import json
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.rest.client.v2_alpha import register, sync
 
 from tests import unittest
@@ -75,6 +76,45 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.code, 403)
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
+    def test_as_ignores_mau(self):
+        """Test that application services can still create users when the MAU
+        limit has been reached. This only works when application service
+        user ip tracking is disabled.
+        """
+
+        # Create and sync so that the MAU counts get updated
+        token1 = self.create_user("kermit1")
+        self.do_sync_for_user(token1)
+        token2 = self.create_user("kermit2")
+        self.do_sync_for_user(token2)
+
+        # check we're testing what we think we are: there should be two active users
+        self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
+
+        # We've created and activated two users, we shouldn't be able to
+        # register new users
+        with self.assertRaises(SynapseError) as cm:
+            self.create_user("kermit3")
+
+        e = cm.exception
+        self.assertEqual(e.code, 403)
+        self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+        # Cheekily add an application service that we use to register a new user
+        # with.
+        as_token = "foobartoken"
+        self.store.services_cache.append(
+            ApplicationService(
+                token=as_token,
+                hostname=self.hs.hostname,
+                id="SomeASID",
+                sender="@as_sender:test",
+                namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
+            )
+        )
+
+        self.create_user("as_kermit4", token=as_token)
+
     def test_allowed_after_a_month_mau(self):
         # Create and sync so that the MAU counts get updated
         token1 = self.create_user("kermit1")
@@ -192,7 +232,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.reactor.advance(100)
         self.assertEqual(2, self.successResultOf(count))
 
-    def create_user(self, localpart):
+    def create_user(self, localpart, token=None):
         request_data = json.dumps(
             {
                 "username": localpart,
@@ -201,7 +241,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
             }
         )
 
-        request, channel = self.make_request("POST", "/register", request_data)
+        channel = self.make_request(
+            "POST", "/register", request_data, access_token=token,
+        )
 
         if channel.code != 200:
             raise HttpResponseException(
@@ -213,7 +255,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         return access_token
 
     def do_sync_for_user(self, token):
-        request, channel = self.make_request("GET", "/sync", access_token=token)
+        channel = self.make_request("GET", "/sync", access_token=token)
 
         if channel.code != 200:
             raise HttpResponseException(
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 7f67ee9e1f..0c6cbbd921 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -20,8 +20,16 @@ from synapse.rest.media.v1.preview_url_resource import (
 
 from . import unittest
 
+try:
+    import lxml
+except ImportError:
+    lxml = None
+
 
 class PreviewTestCase(unittest.TestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
+
     def test_long_summarize(self):
         example_paras = [
             """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
@@ -56,7 +64,7 @@ class PreviewTestCase(unittest.TestCase):
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
 
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -69,7 +77,7 @@ class PreviewTestCase(unittest.TestCase):
 
         desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
 
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø lies in Northern Norway. The municipality has a population of"
             " (2015) 72,066, but with an annual influx of students it has over 75,000"
@@ -96,7 +104,7 @@ class PreviewTestCase(unittest.TestCase):
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
 
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -122,7 +130,7 @@ class PreviewTestCase(unittest.TestCase):
         ]
 
         desc = summarize_paragraphs(example_paras, min_size=200, max_size=500)
-        self.assertEquals(
+        self.assertEqual(
             desc,
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -137,6 +145,9 @@ class PreviewTestCase(unittest.TestCase):
 
 
 class PreviewUrlTestCase(unittest.TestCase):
+    if not lxml:
+        skip = "url preview feature requires lxml"
+
     def test_simple(self):
         html = """
         <html>
@@ -149,7 +160,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment(self):
         html = """
@@ -164,7 +175,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment2(self):
         html = """
@@ -182,7 +193,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(
+        self.assertEqual(
             og,
             {
                 "og:title": "Foo",
@@ -203,7 +214,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Foo", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_missing_title(self):
         html = """
@@ -216,7 +227,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
     def test_h1_as_title(self):
         html = """
@@ -230,7 +241,7 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": "Title", "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
     def test_missing_title_and_broken_h1(self):
         html = """
@@ -244,4 +255,38 @@ class PreviewUrlTestCase(unittest.TestCase):
 
         og = decode_and_calc_og(html, "http://example.com/test.html")
 
-        self.assertEquals(og, {"og:title": None, "og:description": "Some text."})
+        self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
+
+    def test_empty(self):
+        html = ""
+        og = decode_and_calc_og(html, "http://example.com/test.html")
+        self.assertEqual(og, {})
+
+    def test_invalid_encoding(self):
+        """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
+        html = """
+        <html>
+        <head><title>Foo</title></head>
+        <body>
+        Some text.
+        </body>
+        </html>
+        """
+        og = decode_and_calc_og(
+            html, "http://example.com/test.html", "invalid-encoding"
+        )
+        self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
+
+    def test_invalid_encoding2(self):
+        """A body which doesn't match the sent character encoding."""
+        # Note that this contains an invalid UTF-8 sequence in the title.
+        html = b"""
+        <html>
+        <head><title>\xff\xff Foo</title></head>
+        <body>
+        Some text.
+        </body>
+        </html>
+        """
+        og = decode_and_calc_og(html, "http://example.com/test.html")
+        self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
diff --git a/tests/test_server.py b/tests/test_server.py
index c387a85f2e..815da18e65 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -38,7 +38,10 @@ class JsonResourceTests(unittest.TestCase):
         self.reactor = ThreadedMemoryReactorClock()
         self.hs_clock = Clock(self.reactor)
         self.homeserver = setup_test_homeserver(
-            self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
+            self.addCleanup,
+            federation_http_client=None,
+            clock=self.hs_clock,
+            reactor=self.reactor,
         )
 
     def test_handler_for_request(self):
@@ -61,11 +64,10 @@ class JsonResourceTests(unittest.TestCase):
             "test_servlet",
         )
 
-        request, channel = make_request(
+        make_request(
             self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
         )
 
-        self.assertEqual(request.args, {b"a": ["\N{SNOWMAN}".encode("utf8")]})
         self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
 
     def test_callback_direct_exception(self):
@@ -82,7 +84,7 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
 
         self.assertEqual(channel.result["code"], b"500")
 
@@ -106,7 +108,7 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
 
         self.assertEqual(channel.result["code"], b"500")
 
@@ -124,7 +126,7 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
 
         self.assertEqual(channel.result["code"], b"403")
         self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -146,9 +148,7 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        _, channel = make_request(
-            self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
-        )
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar")
 
         self.assertEqual(channel.result["code"], b"400")
         self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -170,7 +170,7 @@ class JsonResourceTests(unittest.TestCase):
         )
 
         # The path was registered as GET, but this is a HEAD request.
-        _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
+        channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertNotIn("body", channel.result)
@@ -202,7 +202,7 @@ class OptionsResourceTests(unittest.TestCase):
         )
 
         # render the request and return the channel
-        _, channel = make_request(self.reactor, site, method, path, shorthand=False)
+        channel = make_request(self.reactor, site, method, path, shorthand=False)
         return channel
 
     def test_unknown_options_request(self):
@@ -275,7 +275,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
 
         self.assertEqual(channel.result["code"], b"200")
         body = channel.result["body"]
@@ -293,7 +293,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
 
         self.assertEqual(channel.result["code"], b"301")
         headers = channel.result["headers"]
@@ -314,7 +314,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
 
         self.assertEqual(channel.result["code"], b"304")
         headers = channel.result["headers"]
@@ -333,7 +333,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        _, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
+        channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertNotIn("body", channel.result)
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 71580b454d..a743cdc3a9 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -53,7 +53,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
     def test_ui_auth(self):
         # Do a UI auth request
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEquals(channel.result["code"], b"401", channel.result)
 
@@ -96,7 +96,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
 
         self.registration_handler.check_username = Mock(return_value=True)
 
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         # We don't bother checking that the response is correct - we'll leave that to
         # other tests. We just want to make sure we're on the right path.
@@ -113,7 +113,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
                 },
             }
         )
-        request, channel = self.make_request(b"POST", self.url, request_data)
+        channel = self.make_request(b"POST", self.url, request_data)
 
         # We're interested in getting a response that looks like a successful
         # registration, not so much that the details are exactly what we want.
diff --git a/tests/test_types.py b/tests/test_types.py
index 480bea1bdc..acdeea7a09 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -58,6 +58,10 @@ class RoomAliasTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(room.to_string(), "#channel:my.domain")
 
+    def test_validate(self):
+        id_string = "#test:domain,test"
+        self.assertFalse(RoomAlias.is_valid(id_string))
+
 
 class GroupIDTestCase(unittest.TestCase):
     def test_parse(self):
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index d232b72264..43898d8142 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,13 @@ import warnings
 from asyncio import Future
 from typing import Any, Awaitable, Callable, TypeVar
 
+from mock import Mock
+
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
 TV = TypeVar("TV")
 
 
@@ -80,3 +87,35 @@ def setup_awaitable_errors() -> Callable[[], None]:
     sys.unraisablehook = unraisablehook  # type: ignore
 
     return cleanup
+
+
+def simple_async_mock(return_value=None, raises=None) -> Mock:
+    # AsyncMock is not available in python3.5, this mimics part of its behaviour
+    async def cb(*args, **kwargs):
+        if raises:
+            raise raises
+        return return_value
+
+    return Mock(side_effect=cb)
+
+
+@attr.s
+class FakeResponse:
+    """A fake twisted.web.IResponse object
+
+    there is a similar class at treq.test.test_response, but it lacks a `phrase`
+    attribute, and didn't support deliverBody until recently.
+    """
+
+    # HTTP response code
+    code = attr.ib(type=int)
+
+    # HTTP response phrase (eg b'OK' for a 200)
+    phrase = attr.ib(type=bytes)
+
+    # body of the response
+    body = attr.ib(type=bytes)
+
+    def deliverBody(self, protocol):
+        protocol.dataReceived(self.body)
+        protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
new file mode 100644
index 0000000000..ad563eb3f0
--- /dev/null
+++ b/tests/test_utils/html_parsers.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from html.parser import HTMLParser
+from typing import Dict, Iterable, List, Optional, Tuple
+
+
+class TestHtmlParser(HTMLParser):
+    """A generic HTML page parser which extracts useful things from the HTML"""
+
+    def __init__(self):
+        super().__init__()
+
+        # a list of links found in the doc
+        self.links = []  # type: List[str]
+
+        # the values of any hidden <input>s: map from name to value
+        self.hiddens = {}  # type: Dict[str, Optional[str]]
+
+        # the values of any radio buttons: map from name to list of values
+        self.radios = {}  # type: Dict[str, List[Optional[str]]]
+
+    def handle_starttag(
+        self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+    ) -> None:
+        attr_dict = dict(attrs)
+        if tag == "a":
+            href = attr_dict["href"]
+            if href:
+                self.links.append(href)
+        elif tag == "input":
+            input_name = attr_dict.get("name")
+            if attr_dict["type"] == "radio":
+                assert input_name
+                self.radios.setdefault(input_name, []).append(attr_dict["value"])
+            elif attr_dict["type"] == "hidden":
+                assert input_name
+                self.hiddens[input_name] = attr_dict["value"]
+
+    def error(_, message):
+        raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..52ae5c5713 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging():
     handler = ToTwistedHandler()
     formatter = logging.Formatter(log_format)
     handler.setFormatter(formatter)
-    handler.addFilter(LoggingContextFilter(request=""))
+    handler.addFilter(LoggingContextFilter())
     root_logger.addHandler(handler)
 
     log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
diff --git a/tests/unittest.py b/tests/unittest.py
index a9d59e31f7..767d5d6077 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
 import inspect
 import logging
 import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
 
 from mock import Mock, patch
 
@@ -46,6 +46,7 @@ from synapse.logging.context import (
 )
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
         """
         Create a the root resource for the test server.
 
-        The default implementation creates a JsonResource and calls each function in
-        `servlets` to register servletes against it
+        The default calls `self.create_resource_dict` and builds the resultant dict
+        into a tree.
         """
-        resource = JsonResource(self.hs)
+        root_resource = Resource()
+        create_resource_tree(self.create_resource_dict(), root_resource)
+        return root_resource
 
-        for servlet in self.servlets:
-            servlet(self.hs, resource)
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        """Create a resource tree for the test server
 
-        return resource
+        A resource tree is a mapping from path to twisted.web.resource.
+
+        The default implementation creates a JsonResource and calls each function in
+        `servlets` to register servlets against it.
+        """
+        servlet_resource = JsonResource(self.hs)
+        for servlet in self.servlets:
+            servlet(self.hs, servlet_resource)
+        return {
+            "/_matrix/client": servlet_resource,
+            "/_synapse/admin": servlet_resource,
+        }
 
     def default_config(self):
         """
@@ -358,24 +372,6 @@ class HomeserverTestCase(TestCase):
         Function to optionally be overridden in subclasses.
         """
 
-    # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
-    # when the `request` arg isn't given, so we define an explicit override to
-    # cover that case.
-    @overload
-    def make_request(
-        self,
-        method: Union[bytes, str],
-        path: Union[bytes, str],
-        content: Union[bytes, dict] = b"",
-        access_token: Optional[str] = None,
-        shorthand: bool = True,
-        federation_auth_origin: str = None,
-        content_is_form: bool = False,
-        await_result: bool = True,
-    ) -> Tuple[SynapseRequest, FakeChannel]:
-        ...
-
-    @overload
     def make_request(
         self,
         method: Union[bytes, str],
@@ -387,21 +383,11 @@ class HomeserverTestCase(TestCase):
         federation_auth_origin: str = None,
         content_is_form: bool = False,
         await_result: bool = True,
-    ) -> Tuple[T, FakeChannel]:
-        ...
-
-    def make_request(
-        self,
-        method: Union[bytes, str],
-        path: Union[bytes, str],
-        content: Union[bytes, dict] = b"",
-        access_token: Optional[str] = None,
-        request: Type[T] = SynapseRequest,
-        shorthand: bool = True,
-        federation_auth_origin: str = None,
-        content_is_form: bool = False,
-        await_result: bool = True,
-    ) -> Tuple[T, FakeChannel]:
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
+        client_ip: str = "127.0.0.1",
+    ) -> FakeChannel:
         """
         Create a SynapseRequest at the path using the method and containing the
         given content.
@@ -423,8 +409,13 @@ class HomeserverTestCase(TestCase):
                  true (the default), will pump the test reactor until the the renderer
                  tells the channel the request is finished.
 
+            custom_headers: (name, value) pairs to add as request headers
+
+            client_ip: The IP to use as the requesting IP. Useful for testing
+                ratelimiting.
+
         Returns:
-            Tuple[synapse.http.site.SynapseRequest, channel]
+            The FakeChannel object which stores the result of the request.
         """
         return make_request(
             self.reactor,
@@ -438,6 +429,8 @@ class HomeserverTestCase(TestCase):
             federation_auth_origin,
             content_is_form,
             await_result,
+            custom_headers,
+            client_ip,
         )
 
     def setup_test_homeserver(self, *args, **kwargs):
@@ -554,7 +547,7 @@ class HomeserverTestCase(TestCase):
         self.hs.config.registration_shared_secret = "shared"
 
         # Create the user
-        request, channel = self.make_request("GET", "/_synapse/admin/v1/register")
+        channel = self.make_request("GET", "/_synapse/admin/v1/register")
         self.assertEqual(channel.code, 200, msg=channel.result)
         nonce = channel.json_body["nonce"]
 
@@ -579,7 +572,7 @@ class HomeserverTestCase(TestCase):
                 "inhibit_login": True,
             }
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/_synapse/admin/v1/register", body.encode("utf8")
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -597,7 +590,7 @@ class HomeserverTestCase(TestCase):
         if device_id:
             body["device_id"] = device_id
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -665,7 +658,7 @@ class HomeserverTestCase(TestCase):
         """
         body = {"type": "m.login.password", "user": username, "password": password}
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
         )
         self.assertEqual(channel.code, 403, channel.result)
@@ -691,13 +684,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
     A federating homeserver that authenticates incoming requests as `other.example.com`.
     """
 
-    def prepare(self, reactor, clock, homeserver):
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+        return d
+
+
+class TestTransportLayerServer(JsonResource):
+    """A test implementation of TransportLayerServer
+
+    authenticates incoming requests as `other.example.com`.
+    """
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
         class Authenticator:
             def authenticate_request(self, request, content):
                 return succeed("other.example.com")
 
+        authenticator = Authenticator()
+
         ratelimiter = FederationRateLimiter(
-            clock,
+            hs.get_clock(),
             FederationRateLimitConfig(
                 window_size=1,
                 sleep_limit=1,
@@ -706,11 +715,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
                 concurrent_requests=1000,
             ),
         )
-        federation_server.register_servlets(
-            homeserver, self.resource, Authenticator(), ratelimiter
-        )
 
-        return super().prepare(reactor, clock, homeserver)
+        federation_server.register_servlets(hs, self, authenticator, ratelimiter)
 
 
 def override_config(extra_config):
@@ -735,3 +741,29 @@ def override_config(extra_config):
         return func
 
     return decorator
+
+
+TV = TypeVar("TV")
+
+
+def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
+    """A test decorator which will skip the decorated test unless a condition is set
+
+    For example:
+
+    class MyTestCase(TestCase):
+        @skip_unless(HAS_FOO, "Cannot test without foo")
+        def test_foo(self):
+            ...
+
+    Args:
+        condition: If true, the test will be skipped
+        reason: the reason to give for skipping the test
+    """
+
+    def decorator(f: TV) -> TV:
+        if not condition:
+            f.skip = reason  # type: ignore
+        return f
+
+    return decorator
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index dadfabd46d..ecd9efc4df 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -25,13 +25,8 @@ from tests.unittest import TestCase
 class DeferredCacheTestCase(TestCase):
     def test_empty(self):
         cache = DeferredCache("test")
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get("foo")
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
     def test_hit(self):
         cache = DeferredCache("test")
@@ -155,13 +150,8 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill(("foo",), 123)
         cache.invalidate(("foo",))
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(("foo",))
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
     def test_invalidate_all(self):
         cache = DeferredCache("testcache")
@@ -215,13 +205,8 @@ class DeferredCacheTestCase(TestCase):
         cache.prefill(2, "two")
         cache.prefill(3, "three")  # 1 will be evicted
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(1)
-        except KeyError:
-            failed = True
-
-        self.assertTrue(failed)
 
         cache.get(2)
         cache.get(3)
@@ -239,13 +224,55 @@ class DeferredCacheTestCase(TestCase):
 
         cache.prefill(3, "three")
 
-        failed = False
-        try:
+        with self.assertRaises(KeyError):
             cache.get(2)
-        except KeyError:
-            failed = True
 
-        self.assertTrue(failed)
+        cache.get(1)
+        cache.get(3)
+
+    def test_eviction_iterable(self):
+        cache = DeferredCache(
+            "test", max_entries=3, apply_cache_factor_from_config=False, iterable=True,
+        )
+
+        cache.prefill(1, ["one", "two"])
+        cache.prefill(2, ["three"])
 
+        # Now access 1 again, thus causing 2 to be least-recently used
+        cache.get(1)
+
+        # Now add an item to the cache, which evicts 2.
+        cache.prefill(3, ["four"])
+        with self.assertRaises(KeyError):
+            cache.get(2)
+
+        # Ensure 1 & 3 are in the cache.
         cache.get(1)
         cache.get(3)
+
+        # Now access 1 again, thus causing 3 to be least-recently used
+        cache.get(1)
+
+        # Now add an item with multiple elements to the cache
+        cache.prefill(4, ["five", "six"])
+
+        # Both 1 and 3 are evicted since there's too many elements.
+        with self.assertRaises(KeyError):
+            cache.get(1)
+        with self.assertRaises(KeyError):
+            cache.get(3)
+
+        # Now add another item to fill the cache again.
+        cache.prefill(5, ["seven"])
+
+        # Now access 4, thus causing 5 to be least-recently used
+        cache.get(4)
+
+        # Add an empty item.
+        cache.prefill(6, [])
+
+        # 5 gets evicted and replaced since an empty element counts as an item.
+        with self.assertRaises(KeyError):
+            cache.get(5)
+        cache.get(4)
+        cache.get(6)
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 0ab0a91483..1ef0af8e8f 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -12,7 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically
 
 from tests.unittest import TestCase
 
@@ -45,3 +47,60 @@ class ChunkSeqTests(TestCase):
         self.assertEqual(
             list(parts), [],
         )
+
+
+class SortTopologically(TestCase):
+    def test_empty(self):
+        "Test that an empty graph works correctly"
+
+        graph = {}  # type: Dict[int, List[int]]
+        self.assertEqual(list(sorted_topologically([], graph)), [])
+
+    def test_handle_empty_graph(self):
+        "Test that a graph where a node doesn't have an entry is treated as empty"
+
+        graph = {}  # type: Dict[int, List[int]]
+
+        # For disconnected nodes the output is simply sorted.
+        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+    def test_disconnected(self):
+        "Test that a graph with no edges work"
+
+        graph = {1: [], 2: []}  # type: Dict[int, List[int]]
+
+        # For disconnected nodes the output is simply sorted.
+        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+    def test_linear(self):
+        "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+    def test_subset(self):
+        "Test that only sorting a subset of the graph works"
+        graph = {1: [], 2: [1], 3: [2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+    def test_fork(self):
+        "Test that a forked graph works"
+        graph = {1: [], 2: [1], 3: [1], 4: [2, 3]}  # type: Dict[int, List[int]]
+
+        # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+        # always get the same one.
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+    def test_duplicates(self):
+        "Test that a graph with duplicate edges work"
+        graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+    def test_multiple_paths(self):
+        "Test that a graph with multiple paths between two nodes work"
+        graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}  # type: Dict[int, List[int]]
+
+        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
diff --git a/tests/utils.py b/tests/utils.py
index c8d3ffbaba..840b657f82 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,13 +20,12 @@ import os
 import time
 import uuid
 import warnings
-from inspect import getcallargs
 from typing import Type
 from urllib import parse as urlparse
 
 from mock import Mock, patch
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
@@ -34,15 +33,12 @@ from synapse.api.room_versions import RoomVersions
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.federation.transport import server as federation_server
-from synapse.http.server import HttpServer
 from synapse.logging.context import current_context, set_current_context
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.storage.prepare_database import prepare_database
-from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
 #
@@ -161,6 +157,7 @@ def default_config(name, parse=False):
             "local": {"per_second": 10000, "burst_count": 10000},
             "remote": {"per_second": 10000, "burst_count": 10000},
         },
+        "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
         "saml2_enabled": False,
         "public_baseurl": None,
         "default_identity_server": None,
@@ -342,32 +339,9 @@ def setup_test_homeserver(
 
     hs.get_auth_handler().validate_hash = validate_hash
 
-    fed = kwargs.get("resource_for_federation", None)
-    if fed:
-        register_federation_servlets(hs, fed)
-
     return hs
 
 
-def register_federation_servlets(hs, resource):
-    federation_server.register_servlets(
-        hs,
-        resource=resource,
-        authenticator=federation_server.Authenticator(hs),
-        ratelimiter=FederationRateLimiter(
-            hs.get_clock(), config=hs.config.rc_federation
-        ),
-    )
-
-
-def get_mock_call_args(pattern_func, mock_func):
-    """ Return the arguments the mock function was called with interpreted
-    by the pattern functions argument list.
-    """
-    invoked_args, invoked_kargs = mock_func.call_args
-    return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
-
-
 def mock_getRawHeaders(headers=None):
     headers = headers if headers is not None else {}
 
@@ -378,7 +352,7 @@ def mock_getRawHeaders(headers=None):
 
 
 # This is a mock /resource/ not an entire server
-class MockHttpResource(HttpServer):
+class MockHttpResource:
     def __init__(self, prefix=""):
         self.callbacks = []  # 3-tuple of method/pattern/function
         self.prefix = prefix
@@ -553,86 +527,6 @@ class MockClock:
         return d
 
 
-def _format_call(args, kwargs):
-    return ", ".join(
-        ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
-    )
-
-
-class DeferredMockCallable:
-    """A callable instance that stores a set of pending call expectations and
-    return values for them. It allows a unit test to assert that the given set
-    of function calls are eventually made, by awaiting on them to be called.
-    """
-
-    def __init__(self):
-        self.expectations = []
-        self.calls = []
-
-    def __call__(self, *args, **kwargs):
-        self.calls.append((args, kwargs))
-
-        if not self.expectations:
-            raise ValueError(
-                "%r has no pending calls to handle call(%s)"
-                % (self, _format_call(args, kwargs))
-            )
-
-        for (call, result, d) in self.expectations:
-            if args == call[1] and kwargs == call[2]:
-                d.callback(None)
-                return result
-
-        failure = AssertionError(
-            "Was not expecting call(%s)" % (_format_call(args, kwargs))
-        )
-
-        for _, _, d in self.expectations:
-            try:
-                d.errback(failure)
-            except Exception:
-                pass
-
-        raise failure
-
-    def expect_call_and_return(self, call, result):
-        self.expectations.append((call, result, defer.Deferred()))
-
-    @defer.inlineCallbacks
-    def await_calls(self, timeout=1000):
-        deferred = defer.DeferredList(
-            [d for _, _, d in self.expectations], fireOnOneErrback=True
-        )
-
-        timer = reactor.callLater(
-            timeout / 1000,
-            deferred.errback,
-            AssertionError(
-                "%d pending calls left: %s"
-                % (
-                    len([e for e in self.expectations if not e[2].called]),
-                    [e for e in self.expectations if not e[2].called],
-                )
-            ),
-        )
-
-        yield deferred
-
-        timer.cancel()
-
-        self.calls = []
-
-    def assert_had_no_calls(self):
-        if self.calls:
-            calls = self.calls
-            self.calls = []
-
-            raise AssertionError(
-                "Expected not to received any calls, got:\n"
-                + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
-            )
-
-
 async def create_room(hs, room_id: str, creator_id: str):
     """Creates and persist a creation event for the given room
     """
diff --git a/tox.ini b/tox.ini
index c232676826..9ff70fe312 100644
--- a/tox.ini
+++ b/tox.ini
@@ -2,12 +2,13 @@
 envlist = packaging, py35, py36, py37, py38, py39, check_codestyle, check_isort
 
 [base]
-extras = test
 deps =
     python-subunit
     junitxml
     coverage
-    coverage-enable-subprocess
+
+    # this is pinned since it's a bit of an obscure package.
+    coverage-enable-subprocess==1.0
 
     # cyptography 2.2 requires setuptools >= 18.5
     #
@@ -17,35 +18,65 @@ deps =
     # installed on that).
     #
     # anyway, make sure that we have a recent enough setuptools.
-    setuptools>=18.5
+    setuptools>=18.5 ; python_version >= '3.6'
+    setuptools>=18.5,<51.0.0 ; python_version < '3.6'
 
     # we also need a semi-recent version of pip, because old ones fail to
     # install the "enum34" dependency of cryptography.
-    pip>=10
-
-setenv =
-    PYTHONDONTWRITEBYTECODE = no_byte_code
-    COVERAGE_PROCESS_START = {toxinidir}/.coveragerc
-
+    pip>=10 ; python_version >= '3.6'
+    pip>=10,<21.0 ; python_version < '3.6'
+
+# directories/files we run the linters on.
+# if you update this list, make sure to do the same in scripts-dev/lint.sh
+lint_targets =
+    setup.py
+    synapse
+    tests
+    scripts
+    scripts-dev
+    stubs
+    contrib
+    synctl
+    synmark
+    .buildkite
+    docker
+
+# default settings for all tox environments
 [testenv]
 deps =
     {[base]deps}
-extras = all, test
-
-whitelist_externals =
-    sh
+extras =
+    # install the optional dependendencies for tox environments without
+    # '-noextras' in their name
+    !noextras: all
+    test
 
 setenv =
-    {[base]setenv}
+    # use a postgres db for tox environments with "-postgres" in the name
+    # (see https://tox.readthedocs.io/en/3.20.1/config.html#factors-and-factor-conditional-settings)
     postgres: SYNAPSE_POSTGRES = 1
+
+    # this is used by .coveragerc to refer to the top of our tree.
     TOP={toxinidir}
 
 passenv = *
 
 commands =
-    /usr/bin/find "{toxinidir}" -name '*.pyc' -delete
-    # Add this so that coverage will run on subprocesses
-    {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
+    # the "env" invocation enables coverage checking for sub-processes. This is
+    # particularly important when running trial with `-j`, since that will make
+    # it run tests in a subprocess, whose coverage would otherwise not be
+    # tracked.  (It also makes an explicit `coverage run` command redundant.)
+    #
+    # (See https://coverage.readthedocs.io/en/coverage-5.3/subprocess.html.
+    # Note that the `coverage.process_startup()` call is done by
+    # `coverage-enable-subprocess`.)
+    #
+    # we use "env" rather than putting a value in `setenv` so that it is not
+    # inherited by other tox environments.
+    #
+    # keep this in sync with the copy in `testenv:py35-old`.
+    #
+    /usr/bin/env COVERAGE_PROCESS_START={toxinidir}/.coveragerc "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
 
 # As of twisted 16.4, trial tries to import the tests as a package (previously
 # it loaded the files explicitly), which means they need to be on the
@@ -77,22 +108,23 @@ skip_install=True
 deps =
     # Old automat version for Twisted
     Automat == 0.3.0
-
     lxml
-    coverage
-    coverage-enable-subprocess
+    {[base]deps}
 
 commands =
-    /usr/bin/find "{toxinidir}" -name '*.pyc' -delete
     # Make all greater-thans equals so we test the oldest version of our direct
     # dependencies, but make the pyopenssl 17.0, which can work against an
     # OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis).
-    /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
+    /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "/psycopg2/d" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
 
     # Install Synapse itself. This won't update any libraries.
     pip install -e ".[test]"
 
-    {envbindir}/coverage run "{envbindir}/trial"  {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
+    # we have to duplicate the command from `testenv` rather than refer to it
+    # as `{[testenv]commands}`, because we run on ubuntu xenial, which has
+    # tox 2.3.1, and https://github.com/tox-dev/tox/issues/208.
+    #
+    /usr/bin/env COVERAGE_PROCESS_START={toxinidir}/.coveragerc "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
 
 [testenv:benchmark]
 deps =
@@ -113,13 +145,13 @@ commands =
 [testenv:check_codestyle]
 extras = lint
 commands =
-    python -m black --check --diff .
-    /bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}"
+    python -m black --check --diff {[base]lint_targets}
+    flake8 {[base]lint_targets} {env:PEP8SUFFIX:}
     {toxinidir}/scripts-dev/config-lint.sh
 
 [testenv:check_isort]
 extras = lint
-commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts"
+commands = isort -c --df --sp setup.cfg {[base]lint_targets}
 
 [testenv:check-newsfragment]
 skip_install = True
@@ -134,6 +166,8 @@ commands = {toxinidir}/scripts-dev/generate_sample_config --check
 skip_install = True
 deps =
     coverage
+    pip>=10 ; python_version >= '3.6'
+    pip>=10,<21.0 ; python_version < '3.6'
 commands=
     coverage combine
     coverage report
@@ -157,7 +191,3 @@ deps =
     {[base]deps}
 extras = all,mypy
 commands = mypy
-
-# To find all folders that pass mypy you run:
-#
-#   find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null"  \; -print