summary refs log tree commit diff
diff options
context:
space:
mode:
authorOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2022-03-02 11:18:09 +0000
committerOlivier Wilkinson (reivilibre) <oliverw@matrix.org>2022-03-02 11:18:09 +0000
commit20e10495f612145e677e7239cd26a342072f3a71 (patch)
treed855c32429ffff746ae6ea3f7e898b2c31025195
parentmatrix.org hotfixes: Back out in-flight state cache changes (#12117) (diff)
parentReword changelog (diff)
downloadsynapse-20e10495f612145e677e7239cd26a342072f3a71.tar.xz
Merge commit 'd8001' (pre v1.54.0rc1) into matrix-org-hotfixes
-rw-r--r--.flake811
-rw-r--r--.github/workflows/tests.yml13
-rw-r--r--CHANGES.md100
-rw-r--r--MANIFEST.in1
-rw-r--r--changelog.d/11835.feature1
-rw-r--r--changelog.d/11972.misc1
-rw-r--r--changelog.d/11974.misc1
-rw-r--r--changelog.d/11984.misc1
-rw-r--r--changelog.d/11985.feature1
-rw-r--r--changelog.d/11991.misc1
-rw-r--r--changelog.d/11992.bugfix1
-rw-r--r--changelog.d/11994.misc1
-rw-r--r--changelog.d/11996.misc1
-rw-r--r--changelog.d/11997.docker1
-rw-r--r--changelog.d/11999.bugfix1
-rw-r--r--changelog.d/12000.feature1
-rw-r--r--changelog.d/12003.doc1
-rw-r--r--changelog.d/12004.doc1
-rw-r--r--changelog.d/12005.misc1
-rw-r--r--changelog.d/12008.removal1
-rw-r--r--changelog.d/12009.feature1
-rw-r--r--changelog.d/12011.misc1
-rw-r--r--changelog.d/12013.misc1
-rw-r--r--changelog.d/12015.misc1
-rw-r--r--changelog.d/12016.misc1
-rw-r--r--changelog.d/12018.removal1
-rw-r--r--changelog.d/12019.misc1
-rw-r--r--changelog.d/12020.feature1
-rw-r--r--changelog.d/12021.feature1
-rw-r--r--changelog.d/12022.feature1
-rw-r--r--changelog.d/12024.bugfix1
-rw-r--r--changelog.d/12025.misc1
-rw-r--r--changelog.d/12030.misc1
-rw-r--r--changelog.d/12033.misc1
-rw-r--r--changelog.d/12034.misc1
-rw-r--r--changelog.d/12039.misc1
-rw-r--r--changelog.d/12041.misc1
-rw-r--r--changelog.d/12051.misc1
-rw-r--r--changelog.d/12052.misc1
-rw-r--r--changelog.d/12056.bugfix1
-rw-r--r--changelog.d/12058.feature1
-rw-r--r--debian/changelog6
-rw-r--r--docker/Dockerfile4
-rw-r--r--docs/manhole.md2
-rw-r--r--docs/modules/third_party_rules_callbacks.md56
-rw-r--r--docs/workers.md92
-rw-r--r--mypy.ini6
-rwxr-xr-xscripts-dev/check-newsfragment2
-rwxr-xr-xscripts-dev/complement.sh4
-rwxr-xr-xscripts/update_synapse_database2
-rw-r--r--setup.cfg12
-rwxr-xr-xsetup.py1
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py2
-rw-r--r--synapse/api/auth_blocking.py2
-rw-r--r--synapse/api/filtering.py9
-rw-r--r--synapse/app/__init__.py6
-rw-r--r--synapse/app/_base.py2
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/app/homeserver.py4
-rw-r--r--synapse/app/phone_stats_home.py14
-rw-r--r--synapse/appservice/__init__.py16
-rw-r--r--synapse/appservice/api.py20
-rw-r--r--synapse/appservice/scheduler.py100
-rw-r--r--synapse/config/appservice.py13
-rw-r--r--synapse/config/cache.py2
-rw-r--r--synapse/config/experimental.py19
-rw-r--r--synapse/config/metrics.py2
-rw-r--r--synapse/config/oidc.py2
-rw-r--r--synapse/config/redis.py2
-rw-r--r--synapse/config/repository.py2
-rw-r--r--synapse/config/saml2.py2
-rw-r--r--synapse/config/tracer.py2
-rw-r--r--synapse/crypto/keyring.py4
-rw-r--r--synapse/event_auth.py6
-rw-r--r--synapse/events/builder.py2
-rw-r--r--synapse/events/snapshot.py9
-rw-r--r--synapse/events/third_party_rules.py58
-rw-r--r--synapse/federation/federation_base.py2
-rw-r--r--synapse/federation/federation_client.py308
-rw-r--r--synapse/federation/sender/__init__.py2
-rw-r--r--synapse/federation/sender/per_destination_queue.py9
-rw-r--r--synapse/federation/sender/transaction_manager.py2
-rw-r--r--synapse/federation/transport/client.py52
-rw-r--r--synapse/federation/transport/server/__init__.py8
-rw-r--r--synapse/federation/transport/server/_base.py2
-rw-r--r--synapse/federation/transport/server/federation.py115
-rw-r--r--synapse/groups/attestations.py2
-rw-r--r--synapse/groups/groups_server.py2
-rw-r--r--synapse/handlers/account.py144
-rw-r--r--synapse/handlers/account_data.py4
-rw-r--r--synapse/handlers/account_validity.py2
-rw-r--r--synapse/handlers/admin.py2
-rw-r--r--synapse/handlers/appservice.py2
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/cas.py2
-rw-r--r--synapse/handlers/deactivate_account.py22
-rw-r--r--synapse/handlers/device.py4
-rw-r--r--synapse/handlers/devicemessage.py2
-rw-r--r--synapse/handlers/directory.py2
-rw-r--r--synapse/handlers/e2e_keys.py4
-rw-r--r--synapse/handlers/e2e_room_keys.py2
-rw-r--r--synapse/handlers/event_auth.py2
-rw-r--r--synapse/handlers/events.py4
-rw-r--r--synapse/handlers/federation.py13
-rw-r--r--synapse/handlers/federation_event.py15
-rw-r--r--synapse/handlers/groups_local.py2
-rw-r--r--synapse/handlers/identity.py2
-rw-r--r--synapse/handlers/initial_sync.py2
-rw-r--r--synapse/handlers/message.py10
-rw-r--r--synapse/handlers/oidc.py2
-rw-r--r--synapse/handlers/pagination.py2
-rw-r--r--synapse/handlers/presence.py4
-rw-r--r--synapse/handlers/profile.py16
-rw-r--r--synapse/handlers/read_marker.py2
-rw-r--r--synapse/handlers/receipts.py4
-rw-r--r--synapse/handlers/register.py2
-rw-r--r--synapse/handlers/room.py10
-rw-r--r--synapse/handlers/room_batch.py2
-rw-r--r--synapse/handlers/room_list.py2
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/handlers/room_summary.py325
-rw-r--r--synapse/handlers/saml.py2
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/set_password.py2
-rw-r--r--synapse/handlers/sso.py2
-rw-r--r--synapse/handlers/state_deltas.py2
-rw-r--r--synapse/handlers/stats.py2
-rw-r--r--synapse/handlers/sync.py11
-rw-r--r--synapse/handlers/typing.py4
-rw-r--r--synapse/handlers/ui_auth/checkers.py4
-rw-r--r--synapse/handlers/user_directory.py2
-rw-r--r--synapse/http/matrixfederationclient.py2
-rw-r--r--synapse/module_api/__init__.py9
-rw-r--r--synapse/notifier.py2
-rw-r--r--synapse/push/__init__.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py4
-rw-r--r--synapse/push/emailpusher.py2
-rw-r--r--synapse/push/httppusher.py4
-rw-r--r--synapse/push/mailer.py2
-rw-r--r--synapse/push/push_rule_evaluator.py8
-rw-r--r--synapse/push/pusherpool.py2
-rw-r--r--synapse/python_dependencies.py107
-rw-r--r--synapse/replication/http/_base.py4
-rw-r--r--synapse/replication/http/devices.py2
-rw-r--r--synapse/replication/http/federation.py10
-rw-r--r--synapse/replication/http/membership.py10
-rw-r--r--synapse/replication/http/register.py4
-rw-r--r--synapse/replication/http/send_event.py2
-rw-r--r--synapse/replication/tcp/client.py4
-rw-r--r--synapse/replication/tcp/handler.py2
-rw-r--r--synapse/replication/tcp/resource.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py24
-rw-r--r--synapse/replication/tcp/streams/events.py2
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/rest/admin/background_updates.py2
-rw-r--r--synapse/rest/admin/devices.py6
-rw-r--r--synapse/rest/admin/event_reports.py4
-rw-r--r--synapse/rest/admin/federation.py8
-rw-r--r--synapse/rest/admin/media.py20
-rw-r--r--synapse/rest/admin/registration_tokens.py6
-rw-r--r--synapse/rest/admin/rooms.py16
-rw-r--r--synapse/rest/admin/statistics.py2
-rw-r--r--synapse/rest/admin/users.py24
-rw-r--r--synapse/rest/client/account.py48
-rw-r--r--synapse/rest/client/account_data.py4
-rw-r--r--synapse/rest/client/capabilities.py6
-rw-r--r--synapse/rest/client/directory.py6
-rw-r--r--synapse/rest/client/events.py2
-rw-r--r--synapse/rest/client/groups.py8
-rw-r--r--synapse/rest/client/initial_sync.py2
-rw-r--r--synapse/rest/client/keys.py2
-rw-r--r--synapse/rest/client/login.py4
-rw-r--r--synapse/rest/client/notifications.py2
-rw-r--r--synapse/rest/client/openid.py2
-rw-r--r--synapse/rest/client/push_rule.py2
-rw-r--r--synapse/rest/client/pusher.py4
-rw-r--r--synapse/rest/client/register.py10
-rw-r--r--synapse/rest/client/relations.py6
-rw-r--r--synapse/rest/client/report_event.py2
-rw-r--r--synapse/rest/client/room.py80
-rw-r--r--synapse/rest/client/room_batch.py2
-rw-r--r--synapse/rest/client/shared_rooms.py2
-rw-r--r--synapse/rest/client/sync.py2
-rw-r--r--synapse/rest/client/tags.py2
-rw-r--r--synapse/rest/client/versions.py2
-rw-r--r--synapse/rest/consent/consent_resource.py2
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py2
-rw-r--r--synapse/rest/media/v1/media_repository.py2
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py2
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py2
-rw-r--r--synapse/rest/media/v1/upload_resource.py2
-rw-r--r--synapse/rest/synapse/client/password_reset.py2
-rw-r--r--synapse/server.py21
-rw-r--r--synapse/server_notices/consent_server_notices.py2
-rw-r--r--synapse/server_notices/resource_limits_server_notices.py2
-rw-r--r--synapse/server_notices/server_notices_manager.py2
-rw-r--r--synapse/state/__init__.py33
-rw-r--r--synapse/storage/databases/__init__.py3
-rw-r--r--synapse/storage/databases/main/appservice.py33
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py112
-rw-r--r--synapse/storage/databases/main/events.py66
-rw-r--r--synapse/storage/databases/main/events_worker.py28
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py2
-rw-r--r--synapse/storage/databases/main/room.py37
-rw-r--r--synapse/storage/databases/main/search.py26
-rw-r--r--synapse/storage/databases/main/state.py27
-rw-r--r--synapse/storage/persist_events.py25
-rw-r--r--synapse/storage/schema/main/delta/68/04partial_state_rooms.sql41
-rw-r--r--synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite22
-rw-r--r--synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py72
-rw-r--r--synapse/streams/events.py2
-rw-r--r--synapse/util/__init__.py4
-rw-r--r--synapse/util/async_helpers.py78
-rw-r--r--synapse/util/caches/descriptors.py64
-rw-r--r--synapse/util/check_dependencies.py127
-rwxr-xr-xsynctl8
-rw-r--r--tests/api/test_auth.py44
-rw-r--r--tests/api/test_filtering.py32
-rw-r--r--tests/api/test_ratelimiting.py46
-rw-r--r--tests/app/test_phone_stats_home.py34
-rw-r--r--tests/appservice/test_scheduler.py117
-rw-r--r--tests/crypto/test_event_signing.py8
-rw-r--r--tests/crypto/test_keyring.py16
-rw-r--r--tests/events/test_snapshot.py2
-rw-r--r--tests/events/test_utils.py16
-rw-r--r--tests/federation/test_complexity.py8
-rw-r--r--tests/federation/test_federation_catch_up.py26
-rw-r--r--tests/federation/test_federation_sender.py6
-rw-r--r--tests/federation/test_federation_server.py12
-rw-r--r--tests/federation/transport/test_knocking.py18
-rw-r--r--tests/federation/transport/test_server.py24
-rw-r--r--tests/handlers/test_appservice.py222
-rw-r--r--tests/handlers/test_auth.py12
-rw-r--r--tests/handlers/test_cas.py2
-rw-r--r--tests/handlers/test_deactivate_account.py2
-rw-r--r--tests/handlers/test_device.py4
-rw-r--r--tests/handlers/test_directory.py34
-rw-r--r--tests/handlers/test_e2e_keys.py2
-rw-r--r--tests/handlers/test_federation.py2
-rw-r--r--tests/handlers/test_message.py2
-rw-r--r--tests/handlers/test_oidc.py6
-rw-r--r--tests/handlers/test_presence.py82
-rw-r--r--tests/handlers/test_profile.py24
-rw-r--r--tests/handlers/test_receipts.py2
-rw-r--r--tests/handlers/test_register.py10
-rw-r--r--tests/handlers/test_room_summary.py124
-rw-r--r--tests/handlers/test_saml.py4
-rw-r--r--tests/handlers/test_stats.py2
-rw-r--r--tests/handlers/test_sync.py10
-rw-r--r--tests/handlers/test_typing.py42
-rw-r--r--tests/handlers/test_user_directory.py6
-rw-r--r--tests/http/federation/test_srv_resolver.py24
-rw-r--r--tests/module_api/test_api.py2
-rw-r--r--tests/push/test_email.py22
-rw-r--r--tests/push/test_http.py22
-rw-r--r--tests/push/test_push_rule_evaluator.py9
-rw-r--r--tests/replication/_base.py6
-rw-r--r--tests/replication/slave/storage/_base.py4
-rw-r--r--tests/replication/slave/storage/test_events.py2
-rw-r--r--tests/replication/tcp/streams/test_account_data.py4
-rw-r--r--tests/replication/tcp/streams/test_events.py4
-rw-r--r--tests/replication/tcp/streams/test_receipts.py4
-rw-r--r--tests/replication/test_federation_sender_shard.py2
-rw-r--r--tests/replication/test_pusher_shard.py2
-rw-r--r--tests/replication/test_sharded_event_persister.py8
-rw-r--r--tests/rest/admin/test_background_updates.py2
-rw-r--r--tests/rest/admin/test_federation.py4
-rw-r--r--tests/rest/admin/test_media.py4
-rw-r--r--tests/rest/admin/test_registration_tokens.py2
-rw-r--r--tests/rest/admin/test_room.py20
-rw-r--r--tests/rest/admin/test_server_notice.py2
-rw-r--r--tests/rest/admin/test_user.py20
-rw-r--r--tests/rest/client/test_account.py238
-rw-r--r--tests/rest/client/test_auth.py70
-rw-r--r--tests/rest/client/test_capabilities.py30
-rw-r--r--tests/rest/client/test_consent.py19
-rw-r--r--tests/rest/client/test_device_lists.py16
-rw-r--r--tests/rest/client/test_ephemeral_message.py19
-rw-r--r--tests/rest/client/test_events.py30
-rw-r--r--tests/rest/client/test_filter.py12
-rw-r--r--tests/rest/client/test_groups.py14
-rw-r--r--tests/rest/client/test_identity.py13
-rw-r--r--tests/rest/client/test_keys.py6
-rw-r--r--tests/rest/client/test_login.py202
-rw-r--r--tests/rest/client/test_password_policy.py39
-rw-r--r--tests/rest/client/test_power_levels.py47
-rw-r--r--tests/rest/client/test_presence.py15
-rw-r--r--tests/rest/client/test_profile.py82
-rw-r--r--tests/rest/client/test_push_rule_attrs.py26
-rw-r--r--tests/rest/client/test_redactions.py26
-rw-r--r--tests/rest/client/test_register.py304
-rw-r--r--tests/rest/client/test_relations.py310
-rw-r--r--tests/rest/client/test_retention.py45
-rw-r--r--tests/rest/client/test_room_batch.py2
-rw-r--r--tests/rest/client/test_rooms.py202
-rw-r--r--tests/rest/client/test_sendtodevice.py8
-rw-r--r--tests/rest/client/test_shadow_banned.py44
-rw-r--r--tests/rest/client/test_shared_rooms.py46
-rw-r--r--tests/rest/client/test_sync.py65
-rw-r--r--tests/rest/client/test_third_party_rules.py233
-rw-r--r--tests/rest/client/test_typing.py20
-rw-r--r--tests/rest/client/test_upgrade_room.py37
-rw-r--r--tests/rest/client/utils.py115
-rw-r--r--tests/rest/media/v1/test_media_storage.py4
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py2
-rw-r--r--tests/storage/databases/main/test_deviceinbox.py2
-rw-r--r--tests/storage/databases/main/test_events_worker.py18
-rw-r--r--tests/storage/databases/main/test_lock.py2
-rw-r--r--tests/storage/databases/main/test_room.py2
-rw-r--r--tests/storage/test__base.py2
-rw-r--r--tests/storage/test_account_data.py2
-rw-r--r--tests/storage/test_appservice.py96
-rw-r--r--tests/storage/test_background_update.py8
-rw-r--r--tests/storage/test_base.py6
-rw-r--r--tests/storage/test_cleanup_extrems.py4
-rw-r--r--tests/storage/test_client_ips.py4
-rw-r--r--tests/storage/test_devices.py2
-rw-r--r--tests/storage/test_directory.py4
-rw-r--r--tests/storage/test_e2e_room_keys.py2
-rw-r--r--tests/storage/test_end_to_end_keys.py2
-rw-r--r--tests/storage/test_event_chain.py4
-rw-r--r--tests/storage/test_event_federation.py2
-rw-r--r--tests/storage/test_event_push_actions.py4
-rw-r--r--tests/storage/test_events.py4
-rw-r--r--tests/storage/test_id_generators.py6
-rw-r--r--tests/storage/test_keys.py4
-rw-r--r--tests/storage/test_main.py10
-rw-r--r--tests/storage/test_monthly_active_users.py2
-rw-r--r--tests/storage/test_profile.py6
-rw-r--r--tests/storage/test_purge.py4
-rw-r--r--tests/storage/test_redaction.py2
-rw-r--r--tests/storage/test_registration.py8
-rw-r--r--tests/storage/test_rollback_worker.py6
-rw-r--r--tests/storage/test_room.py8
-rw-r--r--tests/storage/test_room_search.py125
-rw-r--r--tests/storage/test_roommember.py6
-rw-r--r--tests/storage/test_state.py2
-rw-r--r--tests/storage/test_stream.py2
-rw-r--r--tests/storage/test_transactions.py2
-rw-r--r--tests/storage/test_user_directory.py4
-rw-r--r--tests/test_distributor.py2
-rw-r--r--tests/test_federation.py8
-rw-r--r--tests/test_mau.py2
-rw-r--r--tests/test_state.py61
-rw-r--r--tests/test_terms_auth.py6
-rw-r--r--tests/test_test_utils.py2
-rw-r--r--tests/test_types.py16
-rw-r--r--tests/test_utils/event_injection.py4
-rw-r--r--tests/test_visibility.py4
-rw-r--r--tests/unittest.py20
-rw-r--r--tests/util/caches/test_deferred_cache.py2
-rw-r--r--tests/util/caches/test_descriptors.py80
-rw-r--r--tests/util/test_async_helpers.py160
-rw-r--r--tests/util/test_check_dependencies.py95
-rw-r--r--tests/util/test_expiring_cache.py40
-rw-r--r--tests/util/test_logcontext.py2
-rw-r--r--tests/util/test_lrucache.py140
-rw-r--r--tests/util/test_retryutils.py4
-rw-r--r--tests/util/test_rwlock.py30
-rw-r--r--tests/util/test_treecache.py48
-rw-r--r--tests/utils.py2
-rw-r--r--tox.ini10
363 files changed, 4750 insertions, 3028 deletions
diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000000..acb118c86e
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,11 @@
+# TODO: incorporate this into pyproject.toml if flake8 supports it in the future.
+# See https://github.com/PyCQA/flake8/issues/234
+[flake8]
+# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes
+# for error codes. The ones we ignore are:
+#  W503: line break before binary operator
+#  W504: line break after binary operator
+#  E203: whitespace before ':' (which is contrary to pep8?)
+#  E731: do not assign a lambda expression, use a def
+#  E501: Line too long (black enforces this for us)
+ignore=W503,W504,E203,E731,E501
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index bbf1033bdd..e9e4277322 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -10,12 +10,19 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
+  check-sampleconfig:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+      - run: pip install -e .
+      - run: scripts-dev/generate_sample_config --check
+
   lint:
     runs-on: ubuntu-latest
     strategy:
       matrix:
         toxenv:
-          - "check-sampleconfig"
           - "check_codestyle"
           - "check_isort"
           - "mypy"
@@ -43,7 +50,7 @@ jobs:
           ref: ${{ github.event.pull_request.head.sha }}
           fetch-depth: 0
       - uses: actions/setup-python@v2
-      - run: pip install tox
+      - run: "pip install 'towncrier>=18.6.0rc1'"
       - run: scripts-dev/check-newsfragment
         env:
           PULL_REQUEST_NUMBER: ${{ github.event.number }}
@@ -51,7 +58,7 @@ jobs:
   # Dummy step to gate other tests on without repeating the whole list
   linting-done:
     if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
-    needs: [lint, lint-crlf, lint-newsfile]
+    needs: [lint, lint-crlf, lint-newsfile, check-sampleconfig]
     runs-on: ubuntu-latest
     steps:
       - run: "true"
diff --git a/CHANGES.md b/CHANGES.md
index 81333097ae..5485e8d47e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,3 +1,103 @@
+Synapse 1.54.0rc1 (2022-03-02)
+==============================
+
+Please note that this will be the last release of Synapse that is compatible with Mjolnir 1.3.1 and earlier.
+Administrators of servers which have the Mjolnir module installed are advised to upgrade Mjolnir to version 1.3.2 or later.
+
+
+Features
+--------
+
+- Add support for [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617))
+- Make a `POST` to `/rooms/<room_id>/receipt/m.read/<event_id>` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. ([\#11835](https://github.com/matrix-org/synapse/issues/11835))
+- Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985))
+- Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000))
+- Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). ([\#12001](https://github.com/matrix-org/synapse/issues/12001), [\#12067](https://github.com/matrix-org/synapse/issues/12067))
+- Enable modules to set a custom display name when registering a user. ([\#12009](https://github.com/matrix-org/synapse/issues/12009))
+- Advertise Matrix 1.1 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020))
+- Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. ([\#12021](https://github.com/matrix-org/synapse/issues/12021))
+- Advertise Matrix 1.2 support on `/_matrix/client/versions`. ([\#12022](https://github.com/matrix-org/synapse/issues/12022))
+- Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). ([\#12058](https://github.com/matrix-org/synapse/issues/12058))
+- Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. ([\#12062](https://github.com/matrix-org/synapse/issues/12062))
+
+
+Bugfixes
+--------
+
+- Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992))
+- Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999))
+- Fix a 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024))
+- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an "argument of type 'int' is not iterable" error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037))
+- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. ([\#12056](https://github.com/matrix-org/synapse/issues/12056))
+- Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077))
+- Fix occasional 'Unhandled error in Deferred' error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089))
+- Fix a bug introduced in Synapse 1.51.0rc1 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098))
+- Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. ([\#12100](https://github.com/matrix-org/synapse/issues/12100))
+- Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. ([\#12105](https://github.com/matrix-org/synapse/issues/12105))
+
+
+Updates to the Docker image
+---------------------------
+
+- The Docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997))
+- Use Python 3.9 in Docker images by default. ([\#12112](https://github.com/matrix-org/synapse/issues/12112))
+
+
+Improved Documentation
+----------------------
+
+- Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. ([\#11599](https://github.com/matrix-org/synapse/issues/11599))
+- Explain the meaning of spam checker callbacks' return values. ([\#12003](https://github.com/matrix-org/synapse/issues/12003))
+- Clarify information about external Identity Provider IDs. ([\#12004](https://github.com/matrix-org/synapse/issues/12004))
+
+
+Deprecations and Removals
+-------------------------
+
+- Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. ([\#11865](https://github.com/matrix-org/synapse/issues/11865))
+- Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). ([\#12008](https://github.com/matrix-org/synapse/issues/12008))
+- Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. ([\#12018](https://github.com/matrix-org/synapse/issues/12018))
+- Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#12073](https://github.com/matrix-org/synapse/issues/12073))
+
+
+Internal Changes
+----------------
+
+- Make the `get_room_version` method use `get_room_version_id` to benefit from caching. ([\#11808](https://github.com/matrix-org/synapse/issues/11808))
+- Remove unnecessary condition on knock -> leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900))
+- Add tests for device list changes between local users. ([\#11972](https://github.com/matrix-org/synapse/issues/11972))
+- Optimise calculating `device_list` changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974))
+- Add missing type hints to storage classes. ([\#11984](https://github.com/matrix-org/synapse/issues/11984))
+- Refactor the search code for improved readability. ([\#11991](https://github.com/matrix-org/synapse/issues/11991))
+- Move common deduplication code down into `_auth_and_persist_outliers`. ([\#11994](https://github.com/matrix-org/synapse/issues/11994))
+- Limit concurrent joins from applications services. ([\#11996](https://github.com/matrix-org/synapse/issues/11996))
+- Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. ([\#12005](https://github.com/matrix-org/synapse/issues/12005), [\#12039](https://github.com/matrix-org/synapse/issues/12039))
+- Preparation for faster-room-join work: parse MSC3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011))
+- Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. ([\#12012](https://github.com/matrix-org/synapse/issues/12012))
+- Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. ([\#12013](https://github.com/matrix-org/synapse/issues/12013))
+- Configure `tox` to use `venv` rather than `virtualenv`. ([\#12015](https://github.com/matrix-org/synapse/issues/12015))
+- Fix bug in `StateFilter.return_expanded()` and add some tests. ([\#12016](https://github.com/matrix-org/synapse/issues/12016))
+- Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. ([\#12019](https://github.com/matrix-org/synapse/issues/12019))
+- Update the `olddeps` CI job to use an old version of `markupsafe`. ([\#12025](https://github.com/matrix-org/synapse/issues/12025))
+- Upgrade Mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030))
+- Remove legacy `HomeServer.get_datastore()`. ([\#12031](https://github.com/matrix-org/synapse/issues/12031), [\#12070](https://github.com/matrix-org/synapse/issues/12070))
+- Minor typing fixes. ([\#12034](https://github.com/matrix-org/synapse/issues/12034), [\#12069](https://github.com/matrix-org/synapse/issues/12069))
+- After joining a room, create a dedicated logcontext to process the queued events. ([\#12041](https://github.com/matrix-org/synapse/issues/12041))
+- Tidy up GitHub Actions config which builds distributions for PyPI. ([\#12051](https://github.com/matrix-org/synapse/issues/12051))
+- Move configuration out of `setup.cfg`. ([\#12052](https://github.com/matrix-org/synapse/issues/12052), [\#12059](https://github.com/matrix-org/synapse/issues/12059))
+- Fix error message when a worker process fails to talk to another worker process. ([\#12060](https://github.com/matrix-org/synapse/issues/12060))
+- Fix using the `complement.sh` script without specifying a directory or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063))
+- Add type hints to `tests/rest/client`. ([\#12066](https://github.com/matrix-org/synapse/issues/12066), [\#12072](https://github.com/matrix-org/synapse/issues/12072), [\#12084](https://github.com/matrix-org/synapse/issues/12084), [\#12094](https://github.com/matrix-org/synapse/issues/12094))
+- Add some logging to `/sync` to try and track down #11916. ([\#12068](https://github.com/matrix-org/synapse/issues/12068))
+- Inspect application dependencies using `importlib.metadata` or its backport. ([\#12088](https://github.com/matrix-org/synapse/issues/12088))
+- Use `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092))
+- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to `/versions`. ([\#12099](https://github.com/matrix-org/synapse/issues/12099))
+- Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. ([\#12106](https://github.com/matrix-org/synapse/issues/12106))
+- Improve exception handling for concurrent execution. ([\#12109](https://github.com/matrix-org/synapse/issues/12109))
+- Advertise support for Python 3.10 in packaging files. ([\#12111](https://github.com/matrix-org/synapse/issues/12111))
+- Move CI checks out of tox, to facilitate a move to using poetry. ([\#12119](https://github.com/matrix-org/synapse/issues/12119))
+
+
 Synapse 1.53.0 (2022-02-22)
 ===========================
 
diff --git a/MANIFEST.in b/MANIFEST.in
index c24786c3b3..76d14eb642 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -45,6 +45,7 @@ include book.toml
 include pyproject.toml
 recursive-include changelog.d *
 
+include .flake8
 prune .circleci
 prune .github
 prune .ci
diff --git a/changelog.d/11835.feature b/changelog.d/11835.feature
deleted file mode 100644
index 7cee39b08c..0000000000
--- a/changelog.d/11835.feature
+++ /dev/null
@@ -1 +0,0 @@
-Make a `POST` to `/rooms/<room_id>/receipt/m.read/<event_id>` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push.
diff --git a/changelog.d/11972.misc b/changelog.d/11972.misc
deleted file mode 100644
index 29c38bfd82..0000000000
--- a/changelog.d/11972.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add tests for device list changes between local users.
\ No newline at end of file
diff --git a/changelog.d/11974.misc b/changelog.d/11974.misc
deleted file mode 100644
index 1debad2361..0000000000
--- a/changelog.d/11974.misc
+++ /dev/null
@@ -1 +0,0 @@
-Optimise calculating device_list changes in `/sync`.
diff --git a/changelog.d/11984.misc b/changelog.d/11984.misc
deleted file mode 100644
index 8e405b9226..0000000000
--- a/changelog.d/11984.misc
+++ /dev/null
@@ -1 +0,0 @@
-Add missing type hints to storage classes.
\ No newline at end of file
diff --git a/changelog.d/11985.feature b/changelog.d/11985.feature
deleted file mode 100644
index 120d888a49..0000000000
--- a/changelog.d/11985.feature
+++ /dev/null
@@ -1 +0,0 @@
-Fetch images when previewing Twitter URLs. Contributed by @AndrewRyanChama.
diff --git a/changelog.d/11991.misc b/changelog.d/11991.misc
deleted file mode 100644
index 34a3b3a6b9..0000000000
--- a/changelog.d/11991.misc
+++ /dev/null
@@ -1 +0,0 @@
-Refactor the search code for improved readability.
diff --git a/changelog.d/11992.bugfix b/changelog.d/11992.bugfix
deleted file mode 100644
index f73c86bb25..0000000000
--- a/changelog.d/11992.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary.
diff --git a/changelog.d/11994.misc b/changelog.d/11994.misc
deleted file mode 100644
index d64297dd78..0000000000
--- a/changelog.d/11994.misc
+++ /dev/null
@@ -1 +0,0 @@
-Move common deduplication code down into `_auth_and_persist_outliers`.
diff --git a/changelog.d/11996.misc b/changelog.d/11996.misc
deleted file mode 100644
index 6c675fd193..0000000000
--- a/changelog.d/11996.misc
+++ /dev/null
@@ -1 +0,0 @@
-Limit concurrent joins from applications services.
\ No newline at end of file
diff --git a/changelog.d/11997.docker b/changelog.d/11997.docker
deleted file mode 100644
index 1b3271457e..0000000000
--- a/changelog.d/11997.docker
+++ /dev/null
@@ -1 +0,0 @@
-The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage.
diff --git a/changelog.d/11999.bugfix b/changelog.d/11999.bugfix
deleted file mode 100644
index fd84095900..0000000000
--- a/changelog.d/11999.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room.
diff --git a/changelog.d/12000.feature b/changelog.d/12000.feature
deleted file mode 100644
index 246cc87f0b..0000000000
--- a/changelog.d/12000.feature
+++ /dev/null
@@ -1 +0,0 @@
-Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time.
diff --git a/changelog.d/12003.doc b/changelog.d/12003.doc
deleted file mode 100644
index 1ac8163559..0000000000
--- a/changelog.d/12003.doc
+++ /dev/null
@@ -1 +0,0 @@
-Explain the meaning of spam checker callbacks' return values.
diff --git a/changelog.d/12004.doc b/changelog.d/12004.doc
deleted file mode 100644
index 0b4baef210..0000000000
--- a/changelog.d/12004.doc
+++ /dev/null
@@ -1 +0,0 @@
-Clarify information about external Identity Provider IDs.
diff --git a/changelog.d/12005.misc b/changelog.d/12005.misc
deleted file mode 100644
index 45e21dbe59..0000000000
--- a/changelog.d/12005.misc
+++ /dev/null
@@ -1 +0,0 @@
-Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`.
diff --git a/changelog.d/12008.removal b/changelog.d/12008.removal
deleted file mode 100644
index 57599d9ee9..0000000000
--- a/changelog.d/12008.removal
+++ /dev/null
@@ -1 +0,0 @@
-Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration).
diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature
deleted file mode 100644
index c8a531481e..0000000000
--- a/changelog.d/12009.feature
+++ /dev/null
@@ -1 +0,0 @@
-Enable modules to set a custom display name when registering a user.
diff --git a/changelog.d/12011.misc b/changelog.d/12011.misc
deleted file mode 100644
index 258b0e389f..0000000000
--- a/changelog.d/12011.misc
+++ /dev/null
@@ -1 +0,0 @@
-Preparation for faster-room-join work: parse msc3706 fields in send_join response.
diff --git a/changelog.d/12013.misc b/changelog.d/12013.misc
deleted file mode 100644
index c0fca8dccb..0000000000
--- a/changelog.d/12013.misc
+++ /dev/null
@@ -1 +0,0 @@
-Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server.
diff --git a/changelog.d/12015.misc b/changelog.d/12015.misc
deleted file mode 100644
index 3aa32ab4cf..0000000000
--- a/changelog.d/12015.misc
+++ /dev/null
@@ -1 +0,0 @@
-Configure `tox` to use `venv` rather than `virtualenv`.
diff --git a/changelog.d/12016.misc b/changelog.d/12016.misc
deleted file mode 100644
index 8856ef46a9..0000000000
--- a/changelog.d/12016.misc
+++ /dev/null
@@ -1 +0,0 @@
-Fix bug in `StateFilter.return_expanded()` and add some tests.
\ No newline at end of file
diff --git a/changelog.d/12018.removal b/changelog.d/12018.removal
deleted file mode 100644
index e940b62228..0000000000
--- a/changelog.d/12018.removal
+++ /dev/null
@@ -1 +0,0 @@
-Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported.
diff --git a/changelog.d/12019.misc b/changelog.d/12019.misc
deleted file mode 100644
index b2186320ea..0000000000
--- a/changelog.d/12019.misc
+++ /dev/null
@@ -1 +0,0 @@
-Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms.
\ No newline at end of file
diff --git a/changelog.d/12020.feature b/changelog.d/12020.feature
deleted file mode 100644
index 1ac9d2060e..0000000000
--- a/changelog.d/12020.feature
+++ /dev/null
@@ -1 +0,0 @@
-Advertise Matrix 1.1 support on `/_matrix/client/versions`.
\ No newline at end of file
diff --git a/changelog.d/12021.feature b/changelog.d/12021.feature
deleted file mode 100644
index 01378df8ca..0000000000
--- a/changelog.d/12021.feature
+++ /dev/null
@@ -1 +0,0 @@
-Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`.
\ No newline at end of file
diff --git a/changelog.d/12022.feature b/changelog.d/12022.feature
deleted file mode 100644
index 188fb12570..0000000000
--- a/changelog.d/12022.feature
+++ /dev/null
@@ -1 +0,0 @@
-Advertise Matrix 1.2 support on `/_matrix/client/versions`.
\ No newline at end of file
diff --git a/changelog.d/12024.bugfix b/changelog.d/12024.bugfix
deleted file mode 100644
index 59bcdb93a5..0000000000
--- a/changelog.d/12024.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint.
diff --git a/changelog.d/12025.misc b/changelog.d/12025.misc
deleted file mode 100644
index d9475a7718..0000000000
--- a/changelog.d/12025.misc
+++ /dev/null
@@ -1 +0,0 @@
-Update the `olddeps` CI job to use an old version of `markupsafe`.
diff --git a/changelog.d/12030.misc b/changelog.d/12030.misc
deleted file mode 100644
index 607ee97ce6..0000000000
--- a/changelog.d/12030.misc
+++ /dev/null
@@ -1 +0,0 @@
-Upgrade mypy to version 0.931.
diff --git a/changelog.d/12033.misc b/changelog.d/12033.misc
deleted file mode 100644
index 3af049b969..0000000000
--- a/changelog.d/12033.misc
+++ /dev/null
@@ -1 +0,0 @@
-Deduplicate in-flight requests in `_get_state_for_groups`.
diff --git a/changelog.d/12034.misc b/changelog.d/12034.misc
deleted file mode 100644
index 8374a63220..0000000000
--- a/changelog.d/12034.misc
+++ /dev/null
@@ -1 +0,0 @@
-Minor typing fixes.
diff --git a/changelog.d/12039.misc b/changelog.d/12039.misc
deleted file mode 100644
index 45e21dbe59..0000000000
--- a/changelog.d/12039.misc
+++ /dev/null
@@ -1 +0,0 @@
-Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`.
diff --git a/changelog.d/12041.misc b/changelog.d/12041.misc
deleted file mode 100644
index e56dc093de..0000000000
--- a/changelog.d/12041.misc
+++ /dev/null
@@ -1 +0,0 @@
-After joining a room, create a dedicated logcontext to process the queued events.
diff --git a/changelog.d/12051.misc b/changelog.d/12051.misc
deleted file mode 100644
index 9959191352..0000000000
--- a/changelog.d/12051.misc
+++ /dev/null
@@ -1 +0,0 @@
-Tidy up GitHub Actions config which builds distributions for PyPI.
\ No newline at end of file
diff --git a/changelog.d/12052.misc b/changelog.d/12052.misc
deleted file mode 100644
index fbaff67e95..0000000000
--- a/changelog.d/12052.misc
+++ /dev/null
@@ -1 +0,0 @@
-Move `isort` configuration to `pyproject.toml`.
diff --git a/changelog.d/12056.bugfix b/changelog.d/12056.bugfix
deleted file mode 100644
index 210e30c63f..0000000000
--- a/changelog.d/12056.bugfix
+++ /dev/null
@@ -1 +0,0 @@
-Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens.
\ No newline at end of file
diff --git a/changelog.d/12058.feature b/changelog.d/12058.feature
deleted file mode 100644
index 7b71692229..0000000000
--- a/changelog.d/12058.feature
+++ /dev/null
@@ -1 +0,0 @@
-Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)).
diff --git a/debian/changelog b/debian/changelog
index 574930c085..df3db85b8e 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.54.0~rc1) stable; urgency=medium
+
+  * New synapse release 1.54.0~rc1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Wed, 02 Mar 2022 10:43:22 +0000
+
 matrix-synapse-py3 (1.53.0) stable; urgency=medium
 
   * New synapse release 1.53.0.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index e4c1c19b86..a8bb9b0e7f 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -11,10 +11,10 @@
 # There is an optional PYTHON_VERSION build argument which sets the
 # version of python to build against: for example:
 #
-#    DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.9 .
+#    DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.10 .
 #
 
-ARG PYTHON_VERSION=3.8
+ARG PYTHON_VERSION=3.9
 
 ###
 ### Stage 0: builder
diff --git a/docs/manhole.md b/docs/manhole.md
index 715ed840f2..a82fad0f0f 100644
--- a/docs/manhole.md
+++ b/docs/manhole.md
@@ -94,6 +94,6 @@ As a simple example, retrieving an event from the database:
 
 ```pycon
 >>> from twisted.internet import defer
->>> defer.ensureDeferred(hs.get_datastore().get_event('$1416420717069yeQaw:matrix.org'))
+>>> defer.ensureDeferred(hs.get_datastores().main.get_event('$1416420717069yeQaw:matrix.org'))
 <Deferred at 0x7ff253fc6998 current result: <FrozenEvent event_id='$1416420717069yeQaw:matrix.org', type='m.room.create', state_key=''>>
 ```
diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md
index a3a17096a8..09ac838107 100644
--- a/docs/modules/third_party_rules_callbacks.md
+++ b/docs/modules/third_party_rules_callbacks.md
@@ -148,6 +148,62 @@ deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#c
 
 If multiple modules implement this callback, Synapse runs them all in order.
 
+### `on_profile_update`
+
+_First introduced in Synapse v1.54.0_
+
+```python
+async def on_profile_update(
+    user_id: str,
+    new_profile: "synapse.module_api.ProfileInfo",
+    by_admin: bool,
+    deactivation: bool,
+) -> None:
+```
+
+Called after updating a local user's profile. The update can be triggered either by the
+user themselves or a server admin. The update can also be triggered by a user being
+deactivated (in which case their display name is set to an empty string (`""`) and the
+avatar URL is set to `None`). The module is passed the Matrix ID of the user whose profile
+has been updated, their new profile, as well as a `by_admin` boolean that is `True` if the
+update was triggered by a server admin (and `False` otherwise), and a `deactivated`
+boolean that is `True` if the update is a result of the user being deactivated.
+
+Note that the `by_admin` boolean is also `True` if the profile change happens as a result
+of the user logging in through Single Sign-On, or if a server admin updates their own
+profile.
+
+Per-room profile changes do not trigger this callback to be called. Synapse administrators
+wishing this callback to be called on every profile change are encouraged to disable
+per-room profiles globally using the `allow_per_room_profiles` configuration setting in
+Synapse's configuration file.
+This callback is not called when registering a user, even when setting it through the
+[`get_displayname_for_registration`](https://matrix-org.github.io/synapse/latest/modules/password_auth_provider_callbacks.html#get_displayname_for_registration)
+module callback.
+
+If multiple modules implement this callback, Synapse runs them all in order.
+
+### `on_user_deactivation_status_changed`
+
+_First introduced in Synapse v1.54.0_
+
+```python
+async def on_user_deactivation_status_changed(
+    user_id: str, deactivated: bool, by_admin: bool
+) -> None:
+```
+
+Called after deactivating a local user, or reactivating them through the admin API. The
+deactivation can be triggered either by the user themselves or a server admin. The module
+is passed the Matrix ID of the user whose status is changed, as well as a `deactivated`
+boolean that is `True` if the user is being deactivated and `False` if they're being
+reactivated, and a `by_admin` boolean that is `True` if the deactivation was triggered by
+a server admin (and `False` otherwise). This latter `by_admin` boolean is always `True`
+if the user is being reactivated, as this operation can only be performed through the
+admin API.
+
+If multiple modules implement this callback, Synapse runs them all in order.
+
 ## Example
 
 The example below is a module that implements the third-party rules callback
diff --git a/docs/workers.md b/docs/workers.md
index dadde4d726..b0f8599ef0 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -178,8 +178,11 @@ recommend the use of `systemd` where available: for information on setting up
 
 ### `synapse.app.generic_worker`
 
-This worker can handle API requests matching the following regular
-expressions:
+This worker can handle API requests matching the following regular expressions.
+These endpoints can be routed to any worker. If a worker is set up to handle a
+stream then, for maximum efficiency, additional endpoints should be routed to that
+worker: refer to the [stream writers](#stream-writers) section below for further
+information.
 
     # Sync requests
     ^/_matrix/client/(v2_alpha|r0|v3)/sync$
@@ -209,7 +212,6 @@ expressions:
     ^/_matrix/federation/v1/user/devices/
     ^/_matrix/federation/v1/get_groups_publicised$
     ^/_matrix/key/v2/query
-    ^/_matrix/federation/unstable/org.matrix.msc2946/spaces/
     ^/_matrix/federation/(v1|unstable/org.matrix.msc2946)/hierarchy/
 
     # Inbound federation transaction request
@@ -222,22 +224,25 @@ expressions:
     ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$
     ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$
     ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$
-    ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$
     ^/_matrix/client/(v1|unstable/org.matrix.msc2946)/rooms/.*/hierarchy$
     ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/query$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/changes$
+    ^/_matrix/client/(r0|v3|unstable)/account/3pid$
+    ^/_matrix/client/(r0|v3|unstable)/devices$
     ^/_matrix/client/versions$
     ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_groups$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups/
+    ^/_matrix/client/(r0|v3|unstable)/joined_groups$
+    ^/_matrix/client/(r0|v3|unstable)/publicised_groups$
+    ^/_matrix/client/(r0|v3|unstable)/publicised_groups/
     ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/
     ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$
     ^/_matrix/client/(api/v1|r0|v3|unstable)/search$
 
+    # Encryption requests
+    ^/_matrix/client/(r0|v3|unstable)/keys/query$
+    ^/_matrix/client/(r0|v3|unstable)/keys/changes$
+    ^/_matrix/client/(r0|v3|unstable)/keys/claim$
+    ^/_matrix/client/(r0|v3|unstable)/room_keys/
+
     # Registration/login requests
     ^/_matrix/client/(api/v1|r0|v3|unstable)/login$
     ^/_matrix/client/(r0|v3|unstable)/register$
@@ -251,6 +256,20 @@ expressions:
     ^/_matrix/client/(api/v1|r0|v3|unstable)/join/
     ^/_matrix/client/(api/v1|r0|v3|unstable)/profile/
 
+    # Device requests
+    ^/_matrix/client/(r0|v3|unstable)/sendToDevice/
+
+    # Account data requests
+    ^/_matrix/client/(r0|v3|unstable)/.*/tags
+    ^/_matrix/client/(r0|v3|unstable)/.*/account_data
+
+    # Receipts requests
+    ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt
+    ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers
+
+    # Presence requests
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/
+
 
 Additionally, the following REST endpoints can be handled for GET requests:
 
@@ -330,12 +349,10 @@ Additionally, there is *experimental* support for moving writing of specific
 streams (such as events) off of the main process to a particular worker. (This
 is only supported with Redis-based replication.)
 
-Currently supported streams are `events` and `typing`.
-
 To enable this, the worker must have a HTTP replication listener configured,
-have a `worker_name` and be listed in the `instance_map` config. For example to
-move event persistence off to a dedicated worker, the shared configuration would
-include:
+have a `worker_name` and be listed in the `instance_map` config. The same worker
+can handle multiple streams. For example, to move event persistence off to a
+dedicated worker, the shared configuration would include:
 
 ```yaml
 instance_map:
@@ -347,6 +364,12 @@ stream_writers:
     events: event_persister1
 ```
 
+Some of the streams have associated endpoints which, for maximum efficiency, should
+be routed to the workers handling that stream. See below for the currently supported
+streams and the endpoints associated with them:
+
+##### The `events` stream
+
 The `events` stream also experimentally supports having multiple writers, where
 work is sharded between them by room ID. Note that you *must* restart all worker
 instances when adding or removing event persisters. An example `stream_writers`
@@ -359,6 +382,43 @@ stream_writers:
         - event_persister2
 ```
 
+##### The `typing` stream
+
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `typing` stream:
+
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing
+
+##### The `to_device` stream
+
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `to_device` stream:
+
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/
+
+##### The `account_data` stream
+
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `account_data` stream:
+
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data
+
+##### The `receipts` stream
+
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `receipts` stream:
+
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers
+
+##### The `presence` stream
+
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `presence` stream:
+
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/
+
 #### Background tasks
 
 There is also *experimental* support for moving background tasks to a separate
diff --git a/mypy.ini b/mypy.ini
index 610660b9b7..38ff787609 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -75,16 +75,12 @@ exclude = (?x)
    |tests/push/test_presentable_names.py
    |tests/push/test_push_rule_evaluator.py
    |tests/rest/client/test_account.py
-   |tests/rest/client/test_events.py
    |tests/rest/client/test_filter.py
-   |tests/rest/client/test_groups.py
-   |tests/rest/client/test_register.py
    |tests/rest/client/test_report_event.py
    |tests/rest/client/test_rooms.py
    |tests/rest/client/test_third_party_rules.py
    |tests/rest/client/test_transactions.py
    |tests/rest/client/test_typing.py
-   |tests/rest/client/utils.py
    |tests/rest/key/v2/test_remote_key_resource.py
    |tests/rest/media/v1/test_base.py
    |tests/rest/media/v1/test_media_storage.py
@@ -253,7 +249,7 @@ disallow_untyped_defs = True
 [mypy-tests.rest.admin.*]
 disallow_untyped_defs = True
 
-[mypy-tests.rest.client.test_directory]
+[mypy-tests.rest.client.*]
 disallow_untyped_defs = True
 
 [mypy-tests.federation.transport.test_client]
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index c764011d6a..493558ad65 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -35,7 +35,7 @@ CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing y
 https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog"
 
 # If check-newsfragment returns a non-zero exit code, print the contributing guide and exit
-tox -qe check-newsfragment || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1)
+python -m towncrier.check --compare-with=origin/develop || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1)
 
 echo
 echo "--------------------------"
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index e08ffedaf3..0aecb3daf1 100755
--- a/scripts-dev/complement.sh
+++ b/scripts-dev/complement.sh
@@ -5,7 +5,7 @@
 # It makes a Synapse image which represents the current checkout,
 # builds a synapse-complement image on top, then runs tests with it.
 #
-# By default the script will fetch the latest Complement master branch and
+# By default the script will fetch the latest Complement main branch and
 # run tests with that. This can be overridden to use a custom Complement
 # checkout by setting the COMPLEMENT_DIR environment variable to the
 # filepath of a local Complement checkout or by setting the COMPLEMENT_REF
@@ -32,7 +32,7 @@ cd "$(dirname $0)/.."
 
 # Check for a user-specified Complement checkout
 if [[ -z "$COMPLEMENT_DIR" ]]; then
-  COMPLEMENT_REF=${COMPLEMENT_REF:-master}
+  COMPLEMENT_REF=${COMPLEMENT_REF:-main}
   echo "COMPLEMENT_DIR not set. Fetching Complement checkout from ${COMPLEMENT_REF}..."
   wget -Nq https://github.com/matrix-org/complement/archive/${COMPLEMENT_REF}.tar.gz
   tar -xzf ${COMPLEMENT_REF}.tar.gz
diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database
index 5c6453d77f..f43676afaa 100755
--- a/scripts/update_synapse_database
+++ b/scripts/update_synapse_database
@@ -44,7 +44,7 @@ class MockHomeserver(HomeServer):
 
 
 def run_background_updates(hs):
-    store = hs.get_datastore()
+    store = hs.get_datastores().main
 
     async def run_background_updates():
         await store.db_pool.updates.run_background_updates(sleep=False)
diff --git a/setup.cfg b/setup.cfg
index a0506572d9..6213f3265b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,3 @@
-[trial]
-test_suite = tests
-
 [check-manifest]
 ignore =
     .git-blame-ignore-revs
@@ -10,12 +7,3 @@ ignore =
     pylint.cfg
     tox.ini
 
-[flake8]
-# see https://pycodestyle.readthedocs.io/en/latest/intro.html#error-codes
-# for error codes. The ones we ignore are:
-#  W503: line break before binary operator
-#  W504: line break after binary operator
-#  E203: whitespace before ':' (which is contrary to pep8?)
-#  E731: do not assign a lambda expression, use a def
-#  E501: Line too long (black enforces this for us)
-ignore=W503,W504,E203,E731,E501
diff --git a/setup.py b/setup.py
index c80cb6f207..26f4650348 100755
--- a/setup.py
+++ b/setup.py
@@ -165,6 +165,7 @@ setup(
         "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
     ],
     scripts=["synctl"] + glob.glob("scripts/*"),
     cmdclass={"test": TestCommand},
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 903f2e815d..b21e1ed0f3 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.53.0"
+__version__ = "1.54.0rc1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 683241201c..01c32417d8 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -60,7 +60,7 @@ class Auth:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state = hs.get_state_handler()
         self._account_validity_handler = hs.get_account_validity_handler()
 
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 08fe160c98..22348d2d86 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
 
 class AuthBlocking:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
         self._hs_disabled = hs.config.server.hs_disabled
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index d087c816db..cb532d7238 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -22,6 +22,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     TypeVar,
@@ -150,7 +151,7 @@ def matrix_user_id_validator(user_id_str: str) -> UserID:
 class Filtering:
     def __init__(self, hs: "HomeServer"):
         self._hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
 
@@ -294,7 +295,7 @@ class FilterCollection:
 class Filter:
     def __init__(self, hs: "HomeServer", filter_json: JsonDict):
         self._hs = hs
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self.filter_json = filter_json
 
         self.limit = filter_json.get("limit", 10)
@@ -361,10 +362,10 @@ class Filter:
             return self._check_fields(field_matchers)
         else:
             content = event.get("content")
-            # Content is assumed to be a dict below, so ensure it is. This should
+            # Content is assumed to be a mapping below, so ensure it is. This should
             # always be true for events, but account_data has been allowed to
             # have non-dict content.
-            if not isinstance(content, dict):
+            if not isinstance(content, Mapping):
                 content = {}
 
             sender = event.get("sender", None)
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index ee51480a9e..334c3d2c17 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -15,13 +15,13 @@ import logging
 import sys
 from typing import Container
 
-from synapse import python_dependencies  # noqa: E402
+from synapse.util import check_dependencies
 
 logger = logging.getLogger(__name__)
 
 try:
-    python_dependencies.check_requirements()
-except python_dependencies.DependencyException as e:
+    check_dependencies.check_requirements()
+except check_dependencies.DependencyException as e:
     sys.stderr.writelines(
         e.message  # noqa: B306, DependencyException.message is a property
     )
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 452c0c09d5..3e59805baa 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -448,7 +448,7 @@ async def start(hs: "HomeServer") -> None:
 
     # It is now safe to start your Synapse.
     hs.start_listening()
-    hs.get_datastore().db_pool.start_profiling()
+    hs.get_datastores().main.db_pool.start_profiling()
     hs.get_pusherpool().start()
 
     # Log when we start the shut down process.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index aadc882bf8..1536a42723 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -142,7 +142,7 @@ class KeyUploadServlet(RestServlet):
         """
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.http_client = hs.get_simple_http_client()
         self.main_uri = hs.config.worker.worker_main_http_uri
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index bfb30003c2..a6789a840e 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -59,7 +59,6 @@ from synapse.http.server import (
 from synapse.http.site import SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
-from synapse.python_dependencies import check_requirements
 from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.rest import ClientRestResource
@@ -70,6 +69,7 @@ from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.rest.well_known import well_known_resource
 from synapse.server import HomeServer
 from synapse.storage import DataStore
+from synapse.util.check_dependencies import check_requirements
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.module_loader import load_module
 
@@ -372,7 +372,7 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
 
         await _base.start(hs)
 
-        hs.get_datastore().db_pool.updates.start_doing_background_updates()
+        hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
 
     register_start(start)
 
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 899dba5c3d..40dbdace8e 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -82,7 +82,7 @@ async def phone_stats_home(
     # General statistics
     #
 
-    store = hs.get_datastore()
+    store = hs.get_datastores().main
 
     stats["homeserver"] = hs.config.server.server_name
     stats["server_context"] = hs.config.server.server_context
@@ -170,18 +170,22 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
     # Rather than update on per session basis, batch up the requests.
     # If you increase the loop period, the accuracy of user_daily_visits
     # table will decrease
-    clock.looping_call(hs.get_datastore().generate_user_daily_visits, 5 * 60 * 1000)
+    clock.looping_call(
+        hs.get_datastores().main.generate_user_daily_visits, 5 * 60 * 1000
+    )
 
     # monthly active user limiting functionality
-    clock.looping_call(hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60)
-    hs.get_datastore().reap_monthly_active_users()
+    clock.looping_call(
+        hs.get_datastores().main.reap_monthly_active_users, 1000 * 60 * 60
+    )
+    hs.get_datastores().main.reap_monthly_active_users()
 
     @wrap_as_background_process("generate_monthly_active_users")
     async def generate_monthly_active_users() -> None:
         current_mau_count = 0
         current_mau_count_by_service = {}
         reserved_users: Sized = ()
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
         if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
             current_mau_count = await store.get_monthly_active_count()
             current_mau_count_by_service = (
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index a340a8c9c7..4d3f8e4923 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -31,6 +31,14 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# Type for the `device_one_time_key_counts` field in an appservice transaction
+#   user ID -> {device ID -> {algorithm -> count}}
+TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]]
+
+# Type for the `device_unused_fallback_keys` field in an appservice transaction
+#   user ID -> {device ID -> [algorithm]}
+TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]]
+
 
 class ApplicationServiceState(Enum):
     DOWN = "down"
@@ -72,6 +80,7 @@ class ApplicationService:
         rate_limited: bool = True,
         ip_range_whitelist: Optional[IPSet] = None,
         supports_ephemeral: bool = False,
+        msc3202_transaction_extensions: bool = False,
     ):
         self.token = token
         self.url = (
@@ -84,6 +93,7 @@ class ApplicationService:
         self.id = id
         self.ip_range_whitelist = ip_range_whitelist
         self.supports_ephemeral = supports_ephemeral
+        self.msc3202_transaction_extensions = msc3202_transaction_extensions
 
         if "|" in self.id:
             raise Exception("application service ID cannot contain '|' character")
@@ -339,12 +349,16 @@ class AppServiceTransaction:
         events: List[EventBase],
         ephemeral: List[JsonDict],
         to_device_messages: List[JsonDict],
+        one_time_key_counts: TransactionOneTimeKeyCounts,
+        unused_fallback_keys: TransactionUnusedFallbackKeys,
     ):
         self.service = service
         self.id = id
         self.events = events
         self.ephemeral = ephemeral
         self.to_device_messages = to_device_messages
+        self.one_time_key_counts = one_time_key_counts
+        self.unused_fallback_keys = unused_fallback_keys
 
     async def send(self, as_api: "ApplicationServiceApi") -> bool:
         """Sends this transaction using the provided AS API interface.
@@ -359,6 +373,8 @@ class AppServiceTransaction:
             events=self.events,
             ephemeral=self.ephemeral,
             to_device_messages=self.to_device_messages,
+            one_time_key_counts=self.one_time_key_counts,
+            unused_fallback_keys=self.unused_fallback_keys,
             txn_id=self.id,
         )
 
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 73be7ff3d4..a0ea958af6 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -19,6 +19,11 @@ from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
+from synapse.appservice import (
+    ApplicationService,
+    TransactionOneTimeKeyCounts,
+    TransactionUnusedFallbackKeys,
+)
 from synapse.events import EventBase
 from synapse.events.utils import serialize_event
 from synapse.http.client import SimpleHttpClient
@@ -26,7 +31,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
 
 if TYPE_CHECKING:
-    from synapse.appservice import ApplicationService
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
@@ -219,6 +223,8 @@ class ApplicationServiceApi(SimpleHttpClient):
         events: List[EventBase],
         ephemeral: List[JsonDict],
         to_device_messages: List[JsonDict],
+        one_time_key_counts: TransactionOneTimeKeyCounts,
+        unused_fallback_keys: TransactionUnusedFallbackKeys,
         txn_id: Optional[int] = None,
     ) -> bool:
         """
@@ -252,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient):
         uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
 
         # Never send ephemeral events to appservices that do not support it
-        body: Dict[str, List[JsonDict]] = {"events": serialized_events}
+        body: JsonDict = {"events": serialized_events}
         if service.supports_ephemeral:
             body.update(
                 {
@@ -262,6 +268,16 @@ class ApplicationServiceApi(SimpleHttpClient):
                 }
             )
 
+        if service.msc3202_transaction_extensions:
+            if one_time_key_counts:
+                body[
+                    "org.matrix.msc3202.device_one_time_key_counts"
+                ] = one_time_key_counts
+            if unused_fallback_keys:
+                body[
+                    "org.matrix.msc3202.device_unused_fallback_keys"
+                ] = unused_fallback_keys
+
         try:
             await self.put_json(
                 uri=uri,
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index c42fa32fff..72417151ba 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -54,12 +54,19 @@ from typing import (
     Callable,
     Collection,
     Dict,
+    Iterable,
     List,
     Optional,
     Set,
+    Tuple,
 )
 
-from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.appservice import (
+    ApplicationService,
+    ApplicationServiceState,
+    TransactionOneTimeKeyCounts,
+    TransactionUnusedFallbackKeys,
+)
 from synapse.appservice.api import ApplicationServiceApi
 from synapse.events import EventBase
 from synapse.logging.context import run_in_background
@@ -92,11 +99,11 @@ class ApplicationServiceScheduler:
 
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.as_api = hs.get_application_service_api()
 
         self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
-        self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
+        self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs)
 
     async def start(self) -> None:
         logger.info("Starting appservice scheduler")
@@ -153,7 +160,9 @@ class _ServiceQueuer:
     appservice at a given time.
     """
 
-    def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
+    def __init__(
+        self, txn_ctrl: "_TransactionController", clock: Clock, hs: "HomeServer"
+    ):
         # dict of {service_id: [events]}
         self.queued_events: Dict[str, List[EventBase]] = {}
         # dict of {service_id: [events]}
@@ -165,6 +174,10 @@ class _ServiceQueuer:
         self.requests_in_flight: Set[str] = set()
         self.txn_ctrl = txn_ctrl
         self.clock = clock
+        self._msc3202_transaction_extensions_enabled: bool = (
+            hs.config.experimental.msc3202_transaction_extensions
+        )
+        self._store = hs.get_datastores().main
 
     def start_background_request(self, service: ApplicationService) -> None:
         # start a sender for this appservice if we don't already have one
@@ -202,15 +215,84 @@ class _ServiceQueuer:
                 if not events and not ephemeral and not to_device_messages_to_send:
                     return
 
+                one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
+                unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None
+
+                if (
+                    self._msc3202_transaction_extensions_enabled
+                    and service.msc3202_transaction_extensions
+                ):
+                    # Compute the one-time key counts and fallback key usage states
+                    # for the users which are mentioned in this transaction,
+                    # as well as the appservice's sender.
+                    (
+                        one_time_key_counts,
+                        unused_fallback_keys,
+                    ) = await self._compute_msc3202_otk_counts_and_fallback_keys(
+                        service, events, ephemeral, to_device_messages_to_send
+                    )
+
                 try:
                     await self.txn_ctrl.send(
-                        service, events, ephemeral, to_device_messages_to_send
+                        service,
+                        events,
+                        ephemeral,
+                        to_device_messages_to_send,
+                        one_time_key_counts,
+                        unused_fallback_keys,
                     )
                 except Exception:
                     logger.exception("AS request failed")
         finally:
             self.requests_in_flight.discard(service.id)
 
+    async def _compute_msc3202_otk_counts_and_fallback_keys(
+        self,
+        service: ApplicationService,
+        events: Iterable[EventBase],
+        ephemerals: Iterable[JsonDict],
+        to_device_messages: Iterable[JsonDict],
+    ) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]:
+        """
+        Given a list of the events, ephemeral messages and to-device messages,
+        - first computes a list of application services users that may have
+          interesting updates to the one-time key counts or fallback key usage.
+        - then computes one-time key counts and fallback key usages for those users.
+        Given a list of application service users that are interesting,
+        compute one-time key counts and fallback key usages for the users.
+        """
+
+        # Set of 'interesting' users who may have updates
+        users: Set[str] = set()
+
+        # The sender is always included
+        users.add(service.sender)
+
+        # All AS users that would receive the PDUs or EDUs sent to these rooms
+        # are classed as 'interesting'.
+        rooms_of_interesting_users: Set[str] = set()
+        # PDUs
+        rooms_of_interesting_users.update(event.room_id for event in events)
+        # EDUs
+        rooms_of_interesting_users.update(
+            ephemeral["room_id"] for ephemeral in ephemerals
+        )
+
+        # Look up the AS users in those rooms
+        for room_id in rooms_of_interesting_users:
+            users.update(
+                await self._store.get_app_service_users_in_room(room_id, service)
+            )
+
+        # Add recipients of to-device messages.
+        # device_message["user_id"] is the ID of the recipient.
+        users.update(device_message["user_id"] for device_message in to_device_messages)
+
+        # Compute and return the counts / fallback key usage states
+        otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users)
+        unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users)
+        return otk_counts, unused_fbks
+
 
 class _TransactionController:
     """Transaction manager.
@@ -238,6 +320,8 @@ class _TransactionController:
         events: List[EventBase],
         ephemeral: Optional[List[JsonDict]] = None,
         to_device_messages: Optional[List[JsonDict]] = None,
+        one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
+        unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
     ) -> None:
         """
         Create a transaction with the given data and send to the provided
@@ -248,6 +332,10 @@ class _TransactionController:
             events: The persistent events to include in the transaction.
             ephemeral: The ephemeral events to include in the transaction.
             to_device_messages: The to-device messages to include in the transaction.
+            one_time_key_counts: Counts of remaining one-time keys for relevant
+                appservice devices in the transaction.
+            unused_fallback_keys: Lists of unused fallback keys for relevant
+                appservice devices in the transaction.
         """
         try:
             txn = await self.store.create_appservice_txn(
@@ -255,6 +343,8 @@ class _TransactionController:
                 events=events,
                 ephemeral=ephemeral or [],
                 to_device_messages=to_device_messages or [],
+                one_time_key_counts=one_time_key_counts or {},
+                unused_fallback_keys=unused_fallback_keys or {},
             )
             service_is_up = await self._is_service_up(service)
             if service_is_up:
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 7fad2e0422..439bfe1526 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -166,6 +166,16 @@ def _load_appservice(
 
     supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False)
 
+    # Opt-in flag for the MSC3202-specific transactional behaviour.
+    # When enabled, appservice transactions contain the following information:
+    #  - device One-Time Key counts
+    #  - device unused fallback key usage states
+    msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
+    if not isinstance(msc3202_transaction_extensions, bool):
+        raise ValueError(
+            "The `org.matrix.msc3202` option should be true or false if specified."
+        )
+
     return ApplicationService(
         token=as_info["as_token"],
         hostname=hostname,
@@ -174,8 +184,9 @@ def _load_appservice(
         hs_token=as_info["hs_token"],
         sender=user_id,
         id=as_info["id"],
-        supports_ephemeral=supports_ephemeral,
         protocols=protocols,
         rate_limited=rate_limited,
         ip_range_whitelist=ip_range_whitelist,
+        supports_ephemeral=supports_ephemeral,
+        msc3202_transaction_extensions=msc3202_transaction_extensions,
     )
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index 387ac6d115..9a68da9c33 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -20,7 +20,7 @@ from typing import Callable, Dict, Optional
 
 import attr
 
-from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import DependencyException, check_requirements
 
 from ._base import Config, ConfigError
 
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index bcdeb9ee23..41338b39df 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -47,11 +47,6 @@ class ExperimentalConfig(Config):
         # MSC3030 (Jump to date API endpoint)
         self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False)
 
-        # The portion of MSC3202 which is related to device masquerading.
-        self.msc3202_device_masquerading_enabled: bool = experimental.get(
-            "msc3202_device_masquerading", False
-        )
-
         # MSC2409 (this setting only relates to optionally sending to-device messages).
         # Presence, typing and read receipt EDUs are already sent to application services that
         # have opted in to receive them. If enabled, this adds to-device messages to that list.
@@ -59,9 +54,23 @@ class ExperimentalConfig(Config):
             "msc2409_to_device_messages_enabled", False
         )
 
+        # The portion of MSC3202 which is related to device masquerading.
+        self.msc3202_device_masquerading_enabled: bool = experimental.get(
+            "msc3202_device_masquerading", False
+        )
+
+        # Portion of MSC3202 related to transaction extensions:
+        # sending one-time key counts and fallback key usage to application services.
+        self.msc3202_transaction_extensions: bool = experimental.get(
+            "msc3202_transaction_extensions", False
+        )
+
         # MSC3706 (server-side support for partial state in /send_join responses)
         self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
 
         # experimental support for faster joins over federation (msc2775, msc3706)
         # requires a target server with msc3706_enabled enabled.
         self.faster_joins_enabled: bool = experimental.get("faster_joins", False)
+
+        # MSC3720 (Account status endpoint)
+        self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 1cc26e7578..f62292ecf6 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -15,7 +15,7 @@
 
 import attr
 
-from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import DependencyException, check_requirements
 
 from ._base import Config, ConfigError
 
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index e783b11315..f7e4f9ef22 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -20,11 +20,11 @@ import attr
 
 from synapse.config._util import validate_config
 from synapse.config.sso import SsoAttributeRequirement
-from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.types import JsonDict
 from synapse.util.module_loader import load_module
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
+from ..util.check_dependencies import DependencyException, check_requirements
 from ._base import Config, ConfigError, read_file
 
 DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc.JinjaOidcMappingProvider"
diff --git a/synapse/config/redis.py b/synapse/config/redis.py
index 33104af734..bdb1aac3a2 100644
--- a/synapse/config/redis.py
+++ b/synapse/config/redis.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from synapse.config._base import Config
-from synapse.python_dependencies import check_requirements
+from synapse.util.check_dependencies import check_requirements
 
 
 class RedisConfig(Config):
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 1980351e77..0a0d901bfb 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -20,8 +20,8 @@ from urllib.request import getproxies_environment  # type: ignore
 import attr
 
 from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
-from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.types import JsonDict
+from synapse.util.check_dependencies import DependencyException, check_requirements
 from synapse.util.module_loader import load_module
 
 from ._base import Config, ConfigError
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index ec9d9f65e7..43c456d5c6 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -17,8 +17,8 @@ import logging
 from typing import Any, List, Set
 
 from synapse.config.sso import SsoAttributeRequirement
-from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.types import JsonDict
+from synapse.util.check_dependencies import DependencyException, check_requirements
 from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 21b9a88353..7aff618ea6 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -14,7 +14,7 @@
 
 from typing import Set
 
-from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util.check_dependencies import DependencyException, check_requirements
 
 from ._base import Config, ConfigError
 
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 72d4a69aac..93d56c077a 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -476,7 +476,7 @@ class StoreKeyFetcher(KeyFetcher):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def _fetch_keys(
         self, keys_to_fetch: List[_FetchKeyRequest]
@@ -498,7 +498,7 @@ class BaseV2KeyFetcher(KeyFetcher):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.config = hs.config
 
     async def process_v2_response(
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index eca00bc975..621a3efccc 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -374,9 +374,9 @@ def _is_membership_change_allowed(
         return
 
     # Require the user to be in the room for membership changes other than join/knock.
-    if Membership.JOIN != membership and (
-        RoomVersion.msc2403_knocking and Membership.KNOCK != membership
-    ):
+    # Note that the room version check for knocking is done implicitly by `caller_knocked`
+    # and the ability to set a membership of `knock` in the first place.
+    if Membership.JOIN != membership and Membership.KNOCK != membership:
         # If the user has been invited or has knocked, they are allowed to change their
         # membership event to leave
         if (
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index eb39e0ae32..1ea1bb7d37 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -189,7 +189,7 @@ class EventBuilderFactory:
         self.hostname = hs.hostname
         self.signing_key = hs.signing_key
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
 
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 5833fee25f..46042b2bf7 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -101,6 +101,9 @@ class EventContext:
 
             As with _current_state_ids, this is a private attribute. It should be
             accessed via get_prev_state_ids.
+
+        partial_state: if True, we may be storing this event with a temporary,
+            incomplete state.
     """
 
     rejected: Union[bool, str] = False
@@ -113,12 +116,15 @@ class EventContext:
     _current_state_ids: Optional[StateMap[str]] = None
     _prev_state_ids: Optional[StateMap[str]] = None
 
+    partial_state: bool = False
+
     @staticmethod
     def with_state(
         state_group: Optional[int],
         state_group_before_event: Optional[int],
         current_state_ids: Optional[StateMap[str]],
         prev_state_ids: Optional[StateMap[str]],
+        partial_state: bool,
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ) -> "EventContext":
@@ -129,6 +135,7 @@ class EventContext:
             state_group_before_event=state_group_before_event,
             prev_group=prev_group,
             delta_ids=delta_ids,
+            partial_state=partial_state,
         )
 
     @staticmethod
@@ -170,6 +177,7 @@ class EventContext:
             "prev_group": self.prev_group,
             "delta_ids": _encode_state_dict(self.delta_ids),
             "app_service_id": self.app_service.id if self.app_service else None,
+            "partial_state": self.partial_state,
         }
 
     @staticmethod
@@ -196,6 +204,7 @@ class EventContext:
             prev_group=input["prev_group"],
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
+            partial_state=input.get("partial_state", False),
         )
 
         app_service_id = input["app_service_id"]
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 1bb8ca7145..dd3104faf3 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tupl
 from synapse.api.errors import ModuleFailedException, SynapseError
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
+from synapse.storage.roommember import ProfileInfo
 from synapse.types import Requester, StateMap
 from synapse.util.async_helpers import maybe_awaitable
 
@@ -37,6 +38,8 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
     [str, StateMap[EventBase], str], Awaitable[bool]
 ]
 ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable]
+ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
+ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
 
 
 def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@@ -143,7 +146,7 @@ class ThirdPartyEventRules:
     def __init__(self, hs: "HomeServer"):
         self.third_party_rules = None
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
         self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
@@ -154,6 +157,10 @@ class ThirdPartyEventRules:
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = []
         self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = []
+        self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = []
+        self._on_user_deactivation_status_changed_callbacks: List[
+            ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
+        ] = []
 
     def register_third_party_rules_callbacks(
         self,
@@ -166,6 +173,8 @@ class ThirdPartyEventRules:
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = None,
         on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
+        on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None,
+        on_deactivation: Optional[ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK] = None,
     ) -> None:
         """Register callbacks from modules for each hook."""
         if check_event_allowed is not None:
@@ -187,6 +196,12 @@ class ThirdPartyEventRules:
         if on_new_event is not None:
             self._on_new_event_callbacks.append(on_new_event)
 
+        if on_profile_update is not None:
+            self._on_profile_update_callbacks.append(on_profile_update)
+
+        if on_deactivation is not None:
+            self._on_user_deactivation_status_changed_callbacks.append(on_deactivation)
+
     async def check_event_allowed(
         self, event: EventBase, context: EventContext
     ) -> Tuple[bool, Optional[dict]]:
@@ -334,9 +349,6 @@ class ThirdPartyEventRules:
 
         Args:
             event_id: The ID of the event.
-
-        Raises:
-            ModuleFailureError if a callback raised any exception.
         """
         # Bail out early without hitting the store if we don't have any callbacks
         if len(self._on_new_event_callbacks) == 0:
@@ -370,3 +382,41 @@ class ThirdPartyEventRules:
             state_events[key] = room_state_events[event_id]
 
         return state_events
+
+    async def on_profile_update(
+        self, user_id: str, new_profile: ProfileInfo, by_admin: bool, deactivation: bool
+    ) -> None:
+        """Called after the global profile of a user has been updated. Does not include
+        per-room profile changes.
+
+        Args:
+            user_id: The user whose profile was changed.
+            new_profile: The updated profile for the user.
+            by_admin: Whether the profile update was performed by a server admin.
+            deactivation: Whether this change was made while deactivating the user.
+        """
+        for callback in self._on_profile_update_callbacks:
+            try:
+                await callback(user_id, new_profile, by_admin, deactivation)
+            except Exception as e:
+                logger.exception(
+                    "Failed to run module API callback %s: %s", callback, e
+                )
+
+    async def on_user_deactivation_status_changed(
+        self, user_id: str, deactivated: bool, by_admin: bool
+    ) -> None:
+        """Called after a user has been deactivated or reactivated.
+
+        Args:
+            user_id: The deactivated user.
+            deactivated: Whether the user is now deactivated.
+            by_admin: Whether the deactivation was performed by a server admin.
+        """
+        for callback in self._on_user_deactivation_status_changed_callbacks:
+            try:
+                await callback(user_id, deactivated, by_admin)
+            except Exception as e:
+                logger.exception(
+                    "Failed to run module API callback %s: %s", callback, e
+                )
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index fab6da3c08..41ac49fdc8 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -39,7 +39,7 @@ class FederationBase:
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
         self.spam_checker = hs.get_spam_checker()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self._clock = hs.get_clock()
 
     async def _check_sigs_and_hash(
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c2997997da..64e595e748 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,7 +56,7 @@ from synapse.api.room_versions import (
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.transport.client import SendJoinResponse
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
@@ -615,11 +615,15 @@ class FederationClient(FederationBase):
             synapse_error = e.to_synapse_error()
         # There is no good way to detect an "unknown" endpoint.
         #
-        # Dendrite returns a 404 (with no body); synapse returns a 400
+        # Dendrite returns a 404 (with a body of "404 page not found");
+        # Conduit returns a 404 (with no body); and Synapse returns a 400
         # with M_UNRECOGNISED.
-        return e.code == 404 or (
-            e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED
-        )
+        #
+        # This needs to be rather specific as some endpoints truly do return 404
+        # errors.
+        return (
+            e.code == 404 and (not e.response or e.response == b"404 page not found")
+        ) or (e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED)
 
     async def _try_destination_list(
         self,
@@ -1002,7 +1006,7 @@ class FederationClient(FederationBase):
             )
         except HttpResponseException as e:
             # If an error is received that is due to an unrecognised endpoint,
-            # fallback to the v1 endpoint. Otherwise consider it a legitmate error
+            # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
             # and raise.
             if not self._is_unknown_endpoint(e):
                 raise
@@ -1071,7 +1075,7 @@ class FederationClient(FederationBase):
         except HttpResponseException as e:
             # If an error is received that is due to an unrecognised endpoint,
             # fallback to the v1 endpoint if the room uses old-style event IDs.
-            # Otherwise consider it a legitmate error and raise.
+            # Otherwise, consider it a legitimate error and raise.
             err = e.to_synapse_error()
             if self._is_unknown_endpoint(e, err):
                 if room_version.event_format != EventFormatVersions.V1:
@@ -1132,7 +1136,7 @@ class FederationClient(FederationBase):
             )
         except HttpResponseException as e:
             # If an error is received that is due to an unrecognised endpoint,
-            # fallback to the v1 endpoint. Otherwise consider it a legitmate error
+            # fallback to the v1 endpoint. Otherwise, consider it a legitimate error
             # and raise.
             if not self._is_unknown_endpoint(e):
                 raise
@@ -1358,61 +1362,6 @@ class FederationClient(FederationBase):
         # server doesn't give it to us.
         return None
 
-    async def get_space_summary(
-        self,
-        destinations: Iterable[str],
-        room_id: str,
-        suggested_only: bool,
-        max_rooms_per_space: Optional[int],
-        exclude_rooms: List[str],
-    ) -> "FederationSpaceSummaryResult":
-        """
-        Call other servers to get a summary of the given space
-
-
-        Args:
-            destinations: The remote servers. We will try them in turn, omitting any
-                that have been blacklisted.
-
-            room_id: ID of the space to be queried
-
-            suggested_only:  If true, ask the remote server to only return children
-                with the "suggested" flag set
-
-            max_rooms_per_space: A limit on the number of children to return for each
-                space
-
-            exclude_rooms: A list of room IDs to tell the remote server to skip
-
-        Returns:
-            a parsed FederationSpaceSummaryResult
-
-        Raises:
-            SynapseError if we were unable to get a valid summary from any of the
-               remote servers
-        """
-
-        async def send_request(destination: str) -> FederationSpaceSummaryResult:
-            res = await self.transport_layer.get_space_summary(
-                destination=destination,
-                room_id=room_id,
-                suggested_only=suggested_only,
-                max_rooms_per_space=max_rooms_per_space,
-                exclude_rooms=exclude_rooms,
-            )
-
-            try:
-                return FederationSpaceSummaryResult.from_json_dict(res)
-            except ValueError as e:
-                raise InvalidResponseError(str(e))
-
-        return await self._try_destination_list(
-            "fetch space summary",
-            destinations,
-            send_request,
-            failover_on_unknown_endpoint=True,
-        )
-
     async def get_room_hierarchy(
         self,
         destinations: Iterable[str],
@@ -1458,8 +1407,8 @@ class FederationClient(FederationBase):
                 )
             except HttpResponseException as e:
                 # If an error is received that is due to an unrecognised endpoint,
-                # fallback to the unstable endpoint. Otherwise consider it a
-                # legitmate error and raise.
+                # fallback to the unstable endpoint. Otherwise, consider it a
+                # legitimate error and raise.
                 if not self._is_unknown_endpoint(e):
                     raise
 
@@ -1484,10 +1433,8 @@ class FederationClient(FederationBase):
             if any(not isinstance(e, dict) for e in children_state):
                 raise InvalidResponseError("Invalid event in 'children_state' list")
             try:
-                [
-                    FederationSpaceSummaryEventResult.from_json_dict(e)
-                    for e in children_state
-                ]
+                for child_state in children_state:
+                    _validate_hierarchy_event(child_state)
             except ValueError as e:
                 raise InvalidResponseError(str(e))
 
@@ -1509,62 +1456,12 @@ class FederationClient(FederationBase):
 
             return room, children_state, children, inaccessible_children
 
-        try:
-            result = await self._try_destination_list(
-                "fetch room hierarchy",
-                destinations,
-                send_request,
-                failover_on_unknown_endpoint=True,
-            )
-        except SynapseError as e:
-            # If an unexpected error occurred, re-raise it.
-            if e.code != 502:
-                raise
-
-            logger.debug(
-                "Couldn't fetch room hierarchy, falling back to the spaces API"
-            )
-
-            # Fallback to the old federation API and translate the results if
-            # no servers implement the new API.
-            #
-            # The algorithm below is a bit inefficient as it only attempts to
-            # parse information for the requested room, but the legacy API may
-            # return additional layers.
-            legacy_result = await self.get_space_summary(
-                destinations,
-                room_id,
-                suggested_only,
-                max_rooms_per_space=None,
-                exclude_rooms=[],
-            )
-
-            # Find the requested room in the response (and remove it).
-            for _i, room in enumerate(legacy_result.rooms):
-                if room.get("room_id") == room_id:
-                    break
-            else:
-                # The requested room was not returned, nothing we can do.
-                raise
-            requested_room = legacy_result.rooms.pop(_i)
-
-            # Find any children events of the requested room.
-            children_events = []
-            children_room_ids = set()
-            for event in legacy_result.events:
-                if event.room_id == room_id:
-                    children_events.append(event.data)
-                    children_room_ids.add(event.state_key)
-
-            # Find the children rooms.
-            children = []
-            for room in legacy_result.rooms:
-                if room.get("room_id") in children_room_ids:
-                    children.append(room)
-
-            # It isn't clear from the response whether some of the rooms are
-            # not accessible.
-            result = (requested_room, children_events, children, ())
+        result = await self._try_destination_list(
+            "fetch room hierarchy",
+            destinations,
+            send_request,
+            failover_on_unknown_endpoint=True,
+        )
 
         # Cache the result to avoid fetching data over federation every time.
         self._get_room_hierarchy_cache[(room_id, suggested_only)] = result
@@ -1610,6 +1507,64 @@ class FederationClient(FederationBase):
         except ValueError as e:
             raise InvalidResponseError(str(e))
 
+    async def get_account_status(
+        self, destination: str, user_ids: List[str]
+    ) -> Tuple[JsonDict, List[str]]:
+        """Retrieves account statuses for a given list of users on a given remote
+        homeserver.
+
+        If the request fails for any reason, all user IDs for this destination are marked
+        as failed.
+
+        Args:
+            destination: the destination to contact
+            user_ids: the user ID(s) for which to request account status(es)
+
+        Returns:
+            The account statuses, as well as the list of user IDs for which it was not
+            possible to retrieve a status.
+        """
+        try:
+            res = await self.transport_layer.get_account_status(destination, user_ids)
+        except Exception:
+            # If the query failed for any reason, mark all the users as failed.
+            return {}, user_ids
+
+        statuses = res.get("account_statuses", {})
+        failures = res.get("failures", [])
+
+        if not isinstance(statuses, dict) or not isinstance(failures, list):
+            # Make sure we're not feeding back malformed data back to the caller.
+            logger.warning(
+                "Destination %s responded with malformed data to account_status query",
+                destination,
+            )
+            return {}, user_ids
+
+        for user_id in user_ids:
+            # Any account whose status is missing is a user we failed to receive the
+            # status of.
+            if user_id not in statuses and user_id not in failures:
+                failures.append(user_id)
+
+        # Filter out any user ID that doesn't belong to the remote server that sent its
+        # status (or failure).
+        def filter_user_id(user_id: str) -> bool:
+            try:
+                return UserID.from_string(user_id).domain == destination
+            except SynapseError:
+                # If the user ID doesn't parse, ignore it.
+                return False
+
+        filtered_statuses = dict(
+            # item is a (key, value) tuple, so item[0] is the user ID.
+            filter(lambda item: filter_user_id(item[0]), statuses.items())
+        )
+
+        filtered_failures = list(filter(filter_user_id, failures))
+
+        return filtered_statuses, filtered_failures
+
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class TimestampToEventResponse:
@@ -1648,89 +1603,34 @@ class TimestampToEventResponse:
         return cls(event_id, origin_server_ts, d)
 
 
-@attr.s(frozen=True, slots=True, auto_attribs=True)
-class FederationSpaceSummaryEventResult:
-    """Represents a single event in the result of a successful get_space_summary call.
-
-    It's essentially just a serialised event object, but we do a bit of parsing and
-    validation in `from_json_dict` and store some of the validated properties in
-    object attributes.
-    """
-
-    event_type: str
-    room_id: str
-    state_key: str
-    via: Sequence[str]
-
-    # the raw data, including the above keys
-    data: JsonDict
-
-    @classmethod
-    def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult":
-        """Parse an event within the result of a /spaces/ request
-
-        Args:
-            d: json object to be parsed
-
-        Raises:
-            ValueError if d is not a valid event
-        """
+def _validate_hierarchy_event(d: JsonDict) -> None:
+    """Validate an event within the result of a /hierarchy request
 
-        event_type = d.get("type")
-        if not isinstance(event_type, str):
-            raise ValueError("Invalid event: 'event_type' must be a str")
+    Args:
+        d: json object to be parsed
 
-        room_id = d.get("room_id")
-        if not isinstance(room_id, str):
-            raise ValueError("Invalid event: 'room_id' must be a str")
-
-        state_key = d.get("state_key")
-        if not isinstance(state_key, str):
-            raise ValueError("Invalid event: 'state_key' must be a str")
-
-        content = d.get("content")
-        if not isinstance(content, dict):
-            raise ValueError("Invalid event: 'content' must be a dict")
-
-        via = content.get("via")
-        if not isinstance(via, Sequence):
-            raise ValueError("Invalid event: 'via' must be a list")
-        if any(not isinstance(v, str) for v in via):
-            raise ValueError("Invalid event: 'via' must be a list of strings")
-
-        return cls(event_type, room_id, state_key, via, d)
-
-
-@attr.s(frozen=True, slots=True, auto_attribs=True)
-class FederationSpaceSummaryResult:
-    """Represents the data returned by a successful get_space_summary call."""
+    Raises:
+        ValueError if d is not a valid event
+    """
 
-    rooms: List[JsonDict]
-    events: Sequence[FederationSpaceSummaryEventResult]
+    event_type = d.get("type")
+    if not isinstance(event_type, str):
+        raise ValueError("Invalid event: 'event_type' must be a str")
 
-    @classmethod
-    def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult":
-        """Parse the result of a /spaces/ request
+    room_id = d.get("room_id")
+    if not isinstance(room_id, str):
+        raise ValueError("Invalid event: 'room_id' must be a str")
 
-        Args:
-            d: json object to be parsed
+    state_key = d.get("state_key")
+    if not isinstance(state_key, str):
+        raise ValueError("Invalid event: 'state_key' must be a str")
 
-        Raises:
-            ValueError if d is not a valid /spaces/ response
-        """
-        rooms = d.get("rooms")
-        if not isinstance(rooms, List):
-            raise ValueError("'rooms' must be a list")
-        if any(not isinstance(r, dict) for r in rooms):
-            raise ValueError("Invalid room in 'rooms' list")
-
-        events = d.get("events")
-        if not isinstance(events, Sequence):
-            raise ValueError("'events' must be a list")
-        if any(not isinstance(e, dict) for e in events):
-            raise ValueError("Invalid event in 'events' list")
-        parsed_events = [
-            FederationSpaceSummaryEventResult.from_json_dict(e) for e in events
-        ]
+    content = d.get("content")
+    if not isinstance(content, dict):
+        raise ValueError("Invalid event: 'content' must be a dict")
 
-        return cls(rooms, parsed_events)
+    via = content.get("via")
+    if not isinstance(via, Sequence):
+        raise ValueError("Invalid event: 'via' must be a list")
+    if any(not isinstance(v, str) for v in via):
+        raise ValueError("Invalid event: 'via' must be a list of strings")
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 720d7bd74d..6106a486d1 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -228,7 +228,7 @@ class FederationSender(AbstractFederationSender):
         self.hs = hs
         self.server_name = hs.hostname
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state = hs.get_state_handler()
 
         self.clock = hs.get_clock()
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index c3132f7319..c8768f22bc 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -76,7 +76,7 @@ class PerDestinationQueue:
     ):
         self._server_name = hs.hostname
         self._clock = hs.get_clock()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._transaction_manager = transaction_manager
         self._instance_name = hs.get_instance_name()
         self._federation_shard_config = hs.config.worker.federation_shard_config
@@ -381,9 +381,8 @@ class PerDestinationQueue:
                 )
             )
 
-        last_successful_stream_ordering = self._last_successful_stream_ordering
-
-        if last_successful_stream_ordering is None:
+        _tmp_last_successful_stream_ordering = self._last_successful_stream_ordering
+        if _tmp_last_successful_stream_ordering is None:
             # if it's still None, then this means we don't have the information
             # in our database ­ we haven't successfully sent a PDU to this server
             # (at least since the introduction of the feature tracking
@@ -393,6 +392,8 @@ class PerDestinationQueue:
             self._catching_up = False
             return
 
+        last_successful_stream_ordering: int = _tmp_last_successful_stream_ordering
+
         # get at most 50 catchup room/PDUs
         while True:
             event_ids = await self._store.get_catch_up_room_event_ids(
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 742ee57255..0c1cad86ab 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -53,7 +53,7 @@ class TransactionManager:
     def __init__(self, hs: "synapse.server.HomeServer"):
         self._server_name = hs.hostname
         self.clock = hs.get_clock()  # nb must be called this for @measure_func
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._transaction_actions = TransactionActions(self._store)
         self._transport_layer = hs.get_federation_transport_client()
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 7e510e224a..de6e5f44fe 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -258,8 +258,9 @@ class TransportLayerClient:
         args: dict,
         retry_on_dns_fail: bool,
         ignore_backoff: bool = False,
+        prefix: str = FEDERATION_V1_PREFIX,
     ) -> JsonDict:
-        path = _create_v1_path("/query/%s", query_type)
+        path = _create_path(prefix, "/query/%s", query_type)
 
         return await self.client.get_json(
             destination=destination,
@@ -1178,39 +1179,6 @@ class TransportLayerClient:
 
         return await self.client.get_json(destination=destination, path=path)
 
-    async def get_space_summary(
-        self,
-        destination: str,
-        room_id: str,
-        suggested_only: bool,
-        max_rooms_per_space: Optional[int],
-        exclude_rooms: List[str],
-    ) -> JsonDict:
-        """
-        Args:
-            destination: The remote server
-            room_id: The room ID to ask about.
-            suggested_only: if True, only suggested rooms will be returned
-            max_rooms_per_space: an optional limit to the number of children to be
-               returned per space
-            exclude_rooms: a list of any rooms we can skip
-        """
-        # TODO When switching to the stable endpoint, use GET instead of POST.
-        path = _create_path(
-            FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/spaces/%s", room_id
-        )
-
-        params = {
-            "suggested_only": suggested_only,
-            "exclude_rooms": exclude_rooms,
-        }
-        if max_rooms_per_space is not None:
-            params["max_rooms_per_space"] = max_rooms_per_space
-
-        return await self.client.post_json(
-            destination=destination, path=path, data=params
-        )
-
     async def get_room_hierarchy(
         self, destination: str, room_id: str, suggested_only: bool
     ) -> JsonDict:
@@ -1247,6 +1215,22 @@ class TransportLayerClient:
             args={"suggested_only": "true" if suggested_only else "false"},
         )
 
+    async def get_account_status(
+        self, destination: str, user_ids: List[str]
+    ) -> JsonDict:
+        """
+        Args:
+            destination: The remote server.
+            user_ids: The user ID(s) for which to request account status(es).
+        """
+        path = _create_path(
+            FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc3720/account_status"
+        )
+
+        return await self.client.post_json(
+            destination=destination, path=path, data={"user_ids": user_ids}
+        )
+
 
 def _create_path(federation_prefix: str, path: str, *args: str) -> str:
     """
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index db4fe2c798..67a6347907 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -24,6 +24,7 @@ from synapse.federation.transport.server._base import (
 )
 from synapse.federation.transport.server.federation import (
     FEDERATION_SERVLET_CLASSES,
+    FederationAccountStatusServlet,
     FederationTimestampLookupServlet,
 )
 from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES
@@ -336,6 +337,13 @@ def register_servlets(
             ):
                 continue
 
+            # Only allow the `/account_status` servlet if msc3720 is enabled
+            if (
+                servletclass == FederationAccountStatusServlet
+                and not hs.config.experimental.msc3720_enabled
+            ):
+                continue
+
             servletclass(
                 hs=hs,
                 authenticator=authenticator,
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index dff2b68359..87e99c7ddf 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -55,7 +55,7 @@ class Authenticator:
         self._clock = hs.get_clock()
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.federation_domain_whitelist = (
             hs.config.federation.federation_domain_whitelist
         )
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index e85a8eda5b..aed3d5069c 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -110,7 +110,7 @@ class FederationSendServlet(BaseFederationServerServlet):
             if issue_8631_logger.isEnabledFor(logging.DEBUG):
                 DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"]
                 device_list_updates = [
-                    edu.content
+                    edu.get("content", {})
                     for edu in transaction_data.get("edus", [])
                     if edu.get("edu_type") in DEVICE_UPDATE_EDUS
                 ]
@@ -624,81 +624,6 @@ class FederationVersionServlet(BaseFederationServlet):
         )
 
 
-class FederationSpaceSummaryServlet(BaseFederationServlet):
-    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
-    PATH = "/spaces/(?P<room_id>[^/]*)"
-
-    def __init__(
-        self,
-        hs: "HomeServer",
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        super().__init__(hs, authenticator, ratelimiter, server_name)
-        self.handler = hs.get_room_summary_handler()
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Mapping[bytes, Sequence[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
-
-        max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space")
-        if max_rooms_per_space is not None and max_rooms_per_space < 0:
-            raise SynapseError(
-                400,
-                "Value for 'max_rooms_per_space' must be a non-negative integer",
-                Codes.BAD_JSON,
-            )
-
-        exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[])
-
-        return 200, await self.handler.federation_space_summary(
-            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
-        )
-
-    # TODO When switching to the stable endpoint, remove the POST handler.
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Mapping[bytes, Sequence[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        suggested_only = content.get("suggested_only", False)
-        if not isinstance(suggested_only, bool):
-            raise SynapseError(
-                400, "'suggested_only' must be a boolean", Codes.BAD_JSON
-            )
-
-        exclude_rooms = content.get("exclude_rooms", [])
-        if not isinstance(exclude_rooms, list) or any(
-            not isinstance(x, str) for x in exclude_rooms
-        ):
-            raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON)
-
-        max_rooms_per_space = content.get("max_rooms_per_space")
-        if max_rooms_per_space is not None:
-            if not isinstance(max_rooms_per_space, int):
-                raise SynapseError(
-                    400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON
-                )
-            if max_rooms_per_space < 0:
-                raise SynapseError(
-                    400,
-                    "Value for 'max_rooms_per_space' must be a non-negative integer",
-                    Codes.BAD_JSON,
-                )
-
-        return 200, await self.handler.federation_space_summary(
-            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
-        )
-
-
 class FederationRoomHierarchyServlet(BaseFederationServlet):
     PATH = "/hierarchy/(?P<room_id>[^/]*)"
 
@@ -746,7 +671,7 @@ class RoomComplexityServlet(BaseFederationServlet):
         server_name: str,
     ):
         super().__init__(hs, authenticator, ratelimiter, server_name)
-        self._store = self.hs.get_datastore()
+        self._store = self.hs.get_datastores().main
 
     async def on_GET(
         self,
@@ -766,6 +691,40 @@ class RoomComplexityServlet(BaseFederationServlet):
         return 200, complexity
 
 
+class FederationAccountStatusServlet(BaseFederationServerServlet):
+    PATH = "/query/account_status"
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3720"
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self._account_handler = hs.get_account_handler()
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Mapping[bytes, Sequence[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        if "user_ids" not in content:
+            raise SynapseError(
+                400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
+            )
+
+        statuses, failures = await self._account_handler.get_account_statuses(
+            content["user_ids"],
+            allow_remote=False,
+        )
+
+        return 200, {"account_statuses": statuses, "failures": failures}
+
+
 FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationSendServlet,
     FederationEventServlet,
@@ -792,9 +751,9 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     On3pidBindServlet,
     FederationVersionServlet,
     RoomComplexityServlet,
-    FederationSpaceSummaryServlet,
     FederationRoomHierarchyServlet,
     FederationRoomHierarchyUnstableServlet,
     FederationV1SendKnockServlet,
     FederationMakeKnockServlet,
+    FederationAccountStatusServlet,
 )
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index a87896e538..ed26d6a6ce 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -140,7 +140,7 @@ class GroupAttestionRenewer:
 
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.assestations = hs.get_groups_attestation_signing()
         self.transport_client = hs.get_federation_transport_client()
         self.is_mine_id = hs.is_mine_id
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 449bbc7004..4c3a5a6e24 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -45,7 +45,7 @@ MAX_LONG_DESC_LEN = 10000
 class GroupsServerWorkerHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.room_list_handler = hs.get_room_list_handler()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py
new file mode 100644
index 0000000000..d5badf635b
--- /dev/null
+++ b/synapse/handlers/account.py
@@ -0,0 +1,144 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, List, Tuple
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class AccountHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._main_store = hs.get_datastores().main
+        self._is_mine = hs.is_mine
+        self._federation_client = hs.get_federation_client()
+
+    async def get_account_statuses(
+        self,
+        user_ids: List[str],
+        allow_remote: bool,
+    ) -> Tuple[JsonDict, List[str]]:
+        """Get account statuses for a list of user IDs.
+
+        If one or more account(s) belong to remote homeservers, retrieve their status(es)
+        over federation if allowed.
+
+        Args:
+            user_ids: The list of accounts to retrieve the status of.
+            allow_remote: Whether to try to retrieve the status of remote accounts, if
+                any.
+
+        Returns:
+            The account statuses as well as the list of users whose statuses could not be
+            retrieved.
+
+        Raises:
+            SynapseError if a required parameter is missing or malformed, or if one of
+            the accounts isn't local to this homeserver and allow_remote is False.
+        """
+        statuses = {}
+        failures = []
+        remote_users: List[UserID] = []
+
+        for raw_user_id in user_ids:
+            try:
+                user_id = UserID.from_string(raw_user_id)
+            except SynapseError:
+                raise SynapseError(
+                    400,
+                    f"Not a valid Matrix user ID: {raw_user_id}",
+                    Codes.INVALID_PARAM,
+                )
+
+            if self._is_mine(user_id):
+                status = await self._get_local_account_status(user_id)
+                statuses[user_id.to_string()] = status
+            else:
+                if not allow_remote:
+                    raise SynapseError(
+                        400,
+                        f"Not a local user: {raw_user_id}",
+                        Codes.INVALID_PARAM,
+                    )
+
+                remote_users.append(user_id)
+
+        if allow_remote and len(remote_users) > 0:
+            remote_statuses, remote_failures = await self._get_remote_account_statuses(
+                remote_users,
+            )
+
+            statuses.update(remote_statuses)
+            failures += remote_failures
+
+        return statuses, failures
+
+    async def _get_local_account_status(self, user_id: UserID) -> JsonDict:
+        """Retrieve the status of a local account.
+
+        Args:
+            user_id: The account to retrieve the status of.
+
+        Returns:
+            The account's status.
+        """
+        status = {"exists": False}
+
+        userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
+
+        if userinfo is not None:
+            status = {
+                "exists": True,
+                "deactivated": userinfo.is_deactivated,
+            }
+
+        return status
+
+    async def _get_remote_account_statuses(
+        self, remote_users: List[UserID]
+    ) -> Tuple[JsonDict, List[str]]:
+        """Send out federation requests to retrieve the statuses of remote accounts.
+
+        Args:
+            remote_users: The accounts to retrieve the statuses of.
+
+        Returns:
+            The statuses of the accounts, and a list of accounts for which no status
+            could be retrieved.
+        """
+        # Group remote users by destination, so we only send one request per remote
+        # homeserver.
+        by_destination: Dict[str, List[str]] = {}
+        for user in remote_users:
+            if user.domain not in by_destination:
+                by_destination[user.domain] = []
+
+            by_destination[user.domain].append(user.to_string())
+
+        # Retrieve the statuses and failures for remote accounts.
+        final_statuses: JsonDict = {}
+        final_failures: List[str] = []
+        for destination, users in by_destination.items():
+            statuses, failures = await self._federation_client.get_account_status(
+                destination,
+                users,
+            )
+
+            final_statuses.update(statuses)
+            final_failures += failures
+
+        return final_statuses, final_failures
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index bad48713bc..177b4f8991 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 
 class AccountDataHandler:
     def __init__(self, hs: "HomeServer"):
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._instance_name = hs.get_instance_name()
         self._notifier = hs.get_notifier()
 
@@ -166,7 +166,7 @@ class AccountDataHandler:
 
 class AccountDataEventSource(EventSource[int, JsonDict]):
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def get_current_key(self, direction: str = "f") -> int:
         return self.store.get_max_account_data_stream_id()
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 87e415df75..9d0975f636 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -43,7 +43,7 @@ class AccountValidityHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.config = hs.config
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.send_email_handler = self.hs.get_send_email_handler()
         self.clock = self.hs.get_clock()
 
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 00ab5e79bf..96376963f2 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 class AdminHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index a42c3558e4..e6461cc3c9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -47,7 +47,7 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed
 
 class ApplicationServicesHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.is_mine_id = hs.is_mine_id
         self.appservice_api = hs.get_application_service_api()
         self.scheduler = hs.get_application_service_scheduler()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 572f54b1e3..3e29c96a49 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -194,7 +194,7 @@ class AuthHandler:
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
@@ -1183,7 +1183,7 @@ class AuthHandler:
 
             # No password providers were able to handle this 3pid
             # Check local store
-            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+            user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
                 medium, address
             )
             if not user_id:
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 5d8f6c50a9..7163af8004 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -61,7 +61,7 @@ class CasHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 7a13d76a68..76ae768e6e 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -29,7 +29,7 @@ class DeactivateAccountHandler:
     """Handler which deals with deactivating user accounts."""
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.hs = hs
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
@@ -38,6 +38,7 @@ class DeactivateAccountHandler:
         self._profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
         self._server_name = hs.hostname
+        self._third_party_rules = hs.get_third_party_event_rules()
 
         # Flag that indicates whether the process to part users from rooms is running
         self._user_parter_running = False
@@ -135,9 +136,13 @@ class DeactivateAccountHandler:
         if erase_data:
             user = UserID.from_string(user_id)
             # Remove avatar URL from this user
-            await self._profile_handler.set_avatar_url(user, requester, "", by_admin)
+            await self._profile_handler.set_avatar_url(
+                user, requester, "", by_admin, deactivation=True
+            )
             # Remove displayname from this user
-            await self._profile_handler.set_displayname(user, requester, "", by_admin)
+            await self._profile_handler.set_displayname(
+                user, requester, "", by_admin, deactivation=True
+            )
 
             logger.info("Marking %s as erased", user_id)
             await self.store.mark_user_erased(user_id)
@@ -160,6 +165,13 @@ class DeactivateAccountHandler:
         # Remove account data (including ignored users and push rules).
         await self.store.purge_account_data_for_user(user_id)
 
+        # Let modules know the user has been deactivated.
+        await self._third_party_rules.on_user_deactivation_status_changed(
+            user_id,
+            True,
+            by_admin,
+        )
+
         return identity_server_supports_unbinding
 
     async def _reject_pending_invites_for_user(self, user_id: str) -> None:
@@ -264,6 +276,10 @@ class DeactivateAccountHandler:
         # Mark the user as active.
         await self.store.set_user_deactivated_status(user_id, False)
 
+        await self._third_party_rules.on_user_deactivation_status_changed(
+            user_id, False, True
+        )
+
         # Add the user to the directory, if necessary. Note that
         # this must be done after the user is re-activated, because
         # deactivated users are excluded from the user directory.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 36c05f8363..934b5bd734 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -63,7 +63,7 @@ class DeviceWorkerHandler:
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.state = hs.get_state_handler()
         self.state_store = hs.get_storage().state
@@ -628,7 +628,7 @@ class DeviceListUpdater:
     "Handles incoming device list updates from federation and updates the DB"
 
     def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
         self.device_handler = device_handler
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index b582266af9..4cb725d027 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -43,7 +43,7 @@ class DeviceMessageHandler:
         Args:
             hs: server
         """
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
 
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 082f521791..b7064c6624 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -44,7 +44,7 @@ class DirectoryHandler:
         self.state = hs.get_state_handler()
         self.appservice_handler = hs.get_application_service_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.config = hs.config
         self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
         self.require_membership = hs.config.server.require_membership_for_aliases
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d4dfddf63f..d96456cd40 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)
 
 class E2eKeysHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
         self.is_mine = hs.is_mine
@@ -1335,7 +1335,7 @@ class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
     def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
         self.e2e_keys_handler = e2e_keys_handler
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 12614b2c5d..52e44a2d42 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -45,7 +45,7 @@ class E2eRoomKeysHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         # Used to lock whenever a client is uploading key data.  This prevents collisions
         # between clients trying to upload the details of a new session, given all
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 365063ebdf..d441ebb0ab 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -43,7 +43,7 @@ class EventAuthHandler:
 
     def __init__(self, hs: "HomeServer"):
         self._clock = hs.get_clock()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._server_name = hs.hostname
 
     async def check_auth_rules_from_context(
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index bac5de0526..97e75e60c3 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
 
 class EventStreamHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.hs = hs
 
@@ -134,7 +134,7 @@ class EventStreamHandler:
 
 class EventHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
 
     async def get_event(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e9ac920bcc..eb03a5accb 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -107,7 +107,7 @@ class FederationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
         self.federation_client = hs.get_federation_client()
@@ -519,8 +519,17 @@ class FederationHandler:
                 state_events=state,
             )
 
+            if ret.partial_state:
+                await self.store.store_partial_state_room(room_id, ret.servers_in_room)
+
             max_stream_id = await self._federation_event_handler.process_remote_join(
-                origin, room_id, auth_chain, state, event, room_version_obj
+                origin,
+                room_id,
+                auth_chain,
+                state,
+                event,
+                room_version_obj,
+                partial_state=ret.partial_state,
             )
 
             # We wait here until this instance has seen the events come down
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 7683246bef..4bd87709f3 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -95,7 +95,7 @@ class FederationEventHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._storage = hs.get_storage()
         self._state_store = self._storage.state
 
@@ -397,6 +397,7 @@ class FederationEventHandler:
         state: List[EventBase],
         event: EventBase,
         room_version: RoomVersion,
+        partial_state: bool,
     ) -> int:
         """Persists the events returned by a send_join
 
@@ -412,6 +413,7 @@ class FederationEventHandler:
             event
             room_version: The room version we expect this room to have, and
                 will raise if it doesn't match the version in the create event.
+            partial_state: True if the state omits non-critical membership events
 
         Returns:
             The stream ID after which all events have been persisted.
@@ -453,10 +455,14 @@ class FederationEventHandler:
         )
 
         # and now persist the join event itself.
-        logger.info("Peristing join-via-remote %s", event)
+        logger.info(
+            "Peristing join-via-remote %s (partial_state: %s)", event, partial_state
+        )
         with nested_logging_context(suffix=event.event_id):
             context = await self._state_handler.compute_event_context(
-                event, old_state=state
+                event,
+                old_state=state,
+                partial_state=partial_state,
             )
 
             context = await self._check_event_auth(origin, event, context)
@@ -698,6 +704,8 @@ class FederationEventHandler:
 
         try:
             state = await self._resolve_state_at_missing_prevs(origin, event)
+            # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
+            #   not return partial state
             await self._process_received_pdu(
                 origin, event, state=state, backfilled=backfilled
             )
@@ -1791,6 +1799,7 @@ class FederationEventHandler:
             prev_state_ids=prev_state_ids,
             prev_group=prev_group,
             delta_ids=state_updates,
+            partial_state=context.partial_state,
         )
 
     async def _run_push_actions_and_persist_event(
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 9e270d461b..e7a399787b 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -63,7 +63,7 @@ def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
 class GroupsLocalWorkerHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.room_list_handler = hs.get_room_list_handler()
         self.groups_server_handler = hs.get_groups_server_handler()
         self.transport_client = hs.get_federation_transport_client()
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index c83eaea359..57c9fdfe62 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -49,7 +49,7 @@ id_server_scheme = "https://"
 
 class IdentityHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         # An HTTP client for contacting trusted URLs.
         self.http_client = SimpleHttpClient(hs)
         # An HTTP client for contacting identity servers specified by clients.
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 346a06ff49..344f20f37c 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
 
 class InitialSyncHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.state_handler = hs.get_state_handler()
         self.hs = hs
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 3963a7ac6e..9f06238d99 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -55,8 +55,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
-from synapse.util import json_decoder, json_encoder, log_failure
-from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
+from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
+from synapse.util.async_helpers import Linearizer, gather_results
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client
@@ -75,7 +75,7 @@ class MessageHandler:
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
         self._event_serializer = hs.get_event_client_serializer()
@@ -397,7 +397,7 @@ class EventCreationHandler:
         self.hs = hs
         self.auth = hs.get_auth()
         self._event_auth_handler = hs.get_event_auth_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
@@ -992,6 +992,8 @@ class EventCreationHandler:
             and full_state_ids_at_event
             and builder.internal_metadata.is_historical()
         ):
+            # TODO(faster_joins): figure out how this works, and make sure that the
+            #   old state is complete.
             old_state = await self.store.get_events_as_list(full_state_ids_at_event)
             context = await self.state.compute_event_context(event, old_state=old_state)
         else:
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 8f71d975e9..593a2aac66 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -273,7 +273,7 @@ class OidcProvider:
         token_generator: "OidcSessionTokenGenerator",
         provider: OidcProviderConfig,
     ):
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
         self._token_generator = token_generator
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 973f262964..5c01a426ff 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -127,7 +127,7 @@ class PaginationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
         self.clock = hs.get_clock()
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index b223b72623..c155098bee 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -133,7 +133,7 @@ class BasePresenceHandler(abc.ABC):
 
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.presence_router = hs.get_presence_router()
         self.state = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id
@@ -1541,7 +1541,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
         self.get_presence_handler = hs.get_presence_handler
         self.get_presence_router = hs.get_presence_router
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def get_new_events(
         self,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 36e3ad2ba9..6554c0d3c2 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -54,7 +54,7 @@ class ProfileHandler:
     PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.hs = hs
 
@@ -71,6 +71,8 @@ class ProfileHandler:
 
         self.server_name = hs.config.server.server_name
 
+        self._third_party_rules = hs.get_third_party_event_rules()
+
         if hs.config.worker.run_background_tasks:
             self.clock.looping_call(
                 self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
@@ -171,6 +173,7 @@ class ProfileHandler:
         requester: Requester,
         new_displayname: str,
         by_admin: bool = False,
+        deactivation: bool = False,
     ) -> None:
         """Set the displayname of a user
 
@@ -179,6 +182,7 @@ class ProfileHandler:
             requester: The user attempting to make this change.
             new_displayname: The displayname to give this user.
             by_admin: Whether this change was made by an administrator.
+            deactivation: Whether this change was made while deactivating the user.
         """
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -227,6 +231,10 @@ class ProfileHandler:
             target_user.to_string(), profile
         )
 
+        await self._third_party_rules.on_profile_update(
+            target_user.to_string(), profile, by_admin, deactivation
+        )
+
         await self._update_join_states(requester, target_user)
 
     async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
@@ -261,6 +269,7 @@ class ProfileHandler:
         requester: Requester,
         new_avatar_url: str,
         by_admin: bool = False,
+        deactivation: bool = False,
     ) -> None:
         """Set a new avatar URL for a user.
 
@@ -269,6 +278,7 @@ class ProfileHandler:
             requester: The user attempting to make this change.
             new_avatar_url: The avatar URL to give this user.
             by_admin: Whether this change was made by an administrator.
+            deactivation: Whether this change was made while deactivating the user.
         """
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -315,6 +325,10 @@ class ProfileHandler:
             target_user.to_string(), profile
         )
 
+        await self._third_party_rules.on_profile_update(
+            target_user.to_string(), profile, by_admin, deactivation
+        )
+
         await self._update_join_states(requester, target_user)
 
     @cached()
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index 58593e570e..bad1acc634 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 class ReadMarkerHandler:
     def __init__(self, hs: "HomeServer"):
         self.server_name = hs.config.server.server_name
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.account_data_handler = hs.get_account_data_handler()
         self.read_marker_linearizer = Linearizer(name="read_marker")
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 5cb1ff749d..b4132c353a 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -29,7 +29,7 @@ class ReceiptsHandler:
     def __init__(self, hs: "HomeServer"):
         self.notifier = hs.get_notifier()
         self.server_name = hs.config.server.server_name
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.event_auth_handler = hs.get_event_auth_handler()
 
         self.hs = hs
@@ -163,7 +163,7 @@ class ReceiptsHandler:
 
 class ReceiptEventSource(EventSource[int, JsonDict]):
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.config = hs.config
 
     @staticmethod
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 80320d2c07..05bb1e0225 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -86,7 +86,7 @@ class LoginDict(TypedDict):
 
 class RegistrationHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.hs = hs
         self.auth = hs.get_auth()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a990727fc5..7b965b4b96 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -105,7 +105,7 @@ class EventContext:
 
 class RoomCreationHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.hs = hs
@@ -1115,7 +1115,7 @@ class RoomContextHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
@@ -1246,7 +1246,7 @@ class RoomContextHandler:
 class TimestampLookupHandler:
     def __init__(self, hs: "HomeServer"):
         self.server_name = hs.hostname
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state_handler = hs.get_state_handler()
         self.federation_client = hs.get_federation_client()
 
@@ -1386,7 +1386,7 @@ class TimestampLookupHandler:
 
 class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def get_new_events(
         self,
@@ -1476,7 +1476,7 @@ class RoomShutdownHandler:
         self._room_creation_handler = hs.get_room_creation_handler()
         self._replication = hs.get_replication_data_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def shutdown_room(
         self,
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index f8137ec04c..abbf7b7b27 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 class RoomBatchHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state_store = hs.get_storage().state
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 1a33211a1f..f3577b5d5a 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -49,7 +49,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
 
 class RoomListHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.hs = hs
         self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
         self.response_cache: ResponseCache[
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 02f0389c0e..91722a70f4 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -66,7 +66,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.state_handler = hs.get_state_handler()
         self.config = hs.config
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 4844b69a03..55c2cbdba8 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -15,7 +15,6 @@
 import itertools
 import logging
 import re
-from collections import deque
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple
 
 import attr
@@ -90,7 +89,7 @@ class RoomSummaryHandler:
 
     def __init__(self, hs: "HomeServer"):
         self._event_auth_handler = hs.get_event_auth_handler()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._event_serializer = hs.get_event_client_serializer()
         self._server_name = hs.hostname
         self._federation_client = hs.get_federation_client()
@@ -107,153 +106,6 @@ class RoomSummaryHandler:
             "get_room_hierarchy",
         )
 
-    async def get_space_summary(
-        self,
-        requester: str,
-        room_id: str,
-        suggested_only: bool = False,
-        max_rooms_per_space: Optional[int] = None,
-    ) -> JsonDict:
-        """
-        Implementation of the space summary C-S API
-
-        Args:
-            requester:  user id of the user making this request
-
-            room_id: room id to start the summary at
-
-            suggested_only: whether we should only return children with the "suggested"
-                flag set.
-
-            max_rooms_per_space: an optional limit on the number of child rooms we will
-                return. This does not apply to the root room (ie, room_id), and
-                is overridden by MAX_ROOMS_PER_SPACE.
-
-        Returns:
-            summary dict to return
-        """
-        # First of all, check that the room is accessible.
-        if not await self._is_local_room_accessible(room_id, requester):
-            raise AuthError(
-                403,
-                "User %s not in room %s, and room previews are disabled"
-                % (requester, room_id),
-            )
-
-        # the queue of rooms to process
-        room_queue = deque((_RoomQueueEntry(room_id, ()),))
-
-        # rooms we have already processed
-        processed_rooms: Set[str] = set()
-
-        # events we have already processed. We don't necessarily have their event ids,
-        # so instead we key on (room id, state key)
-        processed_events: Set[Tuple[str, str]] = set()
-
-        rooms_result: List[JsonDict] = []
-        events_result: List[JsonDict] = []
-
-        if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE:
-            max_rooms_per_space = MAX_ROOMS_PER_SPACE
-
-        while room_queue and len(rooms_result) < MAX_ROOMS:
-            queue_entry = room_queue.popleft()
-            room_id = queue_entry.room_id
-            if room_id in processed_rooms:
-                # already done this room
-                continue
-
-            logger.debug("Processing room %s", room_id)
-
-            is_in_room = await self._store.is_host_joined(room_id, self._server_name)
-
-            # The client-specified max_rooms_per_space limit doesn't apply to the
-            # room_id specified in the request, so we ignore it if this is the
-            # first room we are processing.
-            max_children = max_rooms_per_space if processed_rooms else MAX_ROOMS
-
-            if is_in_room:
-                room_entry = await self._summarize_local_room(
-                    requester, None, room_id, suggested_only, max_children
-                )
-
-                events: Sequence[JsonDict] = []
-                if room_entry:
-                    rooms_result.append(room_entry.room)
-                    events = room_entry.children_state_events
-
-                logger.debug(
-                    "Query of local room %s returned events %s",
-                    room_id,
-                    ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
-                )
-            else:
-                fed_rooms = await self._summarize_remote_room(
-                    queue_entry,
-                    suggested_only,
-                    max_children,
-                    exclude_rooms=processed_rooms,
-                )
-
-                # The results over federation might include rooms that the we,
-                # as the requesting server, are allowed to see, but the requesting
-                # user is not permitted see.
-                #
-                # Filter the returned results to only what is accessible to the user.
-                events = []
-                for room_entry in fed_rooms:
-                    room = room_entry.room
-                    fed_room_id = room_entry.room_id
-
-                    # The user can see the room, include it!
-                    if await self._is_remote_room_accessible(
-                        requester, fed_room_id, room
-                    ):
-                        # Before returning to the client, remove the allowed_room_ids
-                        # and allowed_spaces keys.
-                        room.pop("allowed_room_ids", None)
-                        room.pop("allowed_spaces", None)  # historical
-
-                        rooms_result.append(room)
-                        events.extend(room_entry.children_state_events)
-
-                    # All rooms returned don't need visiting again (even if the user
-                    # didn't have access to them).
-                    processed_rooms.add(fed_room_id)
-
-                logger.debug(
-                    "Query of %s returned rooms %s, events %s",
-                    room_id,
-                    [room_entry.room.get("room_id") for room_entry in fed_rooms],
-                    ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
-                )
-
-            # the room we queried may or may not have been returned, but don't process
-            # it again, anyway.
-            processed_rooms.add(room_id)
-
-            # XXX: is it ok that we blindly iterate through any events returned by
-            #   a remote server, whether or not they actually link to any rooms in our
-            #   tree?
-            for ev in events:
-                # remote servers might return events we have already processed
-                # (eg, Dendrite returns inward pointers as well as outward ones), so
-                # we need to filter them out, to avoid returning duplicate links to the
-                # client.
-                ev_key = (ev["room_id"], ev["state_key"])
-                if ev_key in processed_events:
-                    continue
-                events_result.append(ev)
-
-                # add the child to the queue. we have already validated
-                # that the vias are a list of server names.
-                room_queue.append(
-                    _RoomQueueEntry(ev["state_key"], ev["content"]["via"])
-                )
-                processed_events.add(ev_key)
-
-        return {"rooms": rooms_result, "events": events_result}
-
     async def get_room_hierarchy(
         self,
         requester: Requester,
@@ -398,8 +250,6 @@ class RoomSummaryHandler:
                     None,
                     room_id,
                     suggested_only,
-                    # Do not limit the maximum children.
-                    max_children=None,
                 )
 
             # Otherwise, attempt to use information for federation.
@@ -488,74 +338,6 @@ class RoomSummaryHandler:
 
         return result
 
-    async def federation_space_summary(
-        self,
-        origin: str,
-        room_id: str,
-        suggested_only: bool,
-        max_rooms_per_space: Optional[int],
-        exclude_rooms: Iterable[str],
-    ) -> JsonDict:
-        """
-        Implementation of the space summary Federation API
-
-        Args:
-            origin: The server requesting the spaces summary.
-
-            room_id: room id to start the summary at
-
-            suggested_only: whether we should only return children with the "suggested"
-                flag set.
-
-            max_rooms_per_space: an optional limit on the number of child rooms we will
-                return. Unlike the C-S API, this applies to the root room (room_id).
-                It is clipped to MAX_ROOMS_PER_SPACE.
-
-            exclude_rooms: a list of rooms to skip over (presumably because the
-                calling server has already seen them).
-
-        Returns:
-            summary dict to return
-        """
-        # the queue of rooms to process
-        room_queue = deque((room_id,))
-
-        # the set of rooms that we should not walk further. Initialise it with the
-        # excluded-rooms list; we will add other rooms as we process them so that
-        # we do not loop.
-        processed_rooms: Set[str] = set(exclude_rooms)
-
-        rooms_result: List[JsonDict] = []
-        events_result: List[JsonDict] = []
-
-        # Set a limit on the number of rooms to return.
-        if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE:
-            max_rooms_per_space = MAX_ROOMS_PER_SPACE
-
-        while room_queue and len(rooms_result) < MAX_ROOMS:
-            room_id = room_queue.popleft()
-            if room_id in processed_rooms:
-                # already done this room
-                continue
-
-            room_entry = await self._summarize_local_room(
-                None, origin, room_id, suggested_only, max_rooms_per_space
-            )
-
-            processed_rooms.add(room_id)
-
-            if room_entry:
-                rooms_result.append(room_entry.room)
-                events_result.extend(room_entry.children_state_events)
-
-                # add any children to the queue
-                room_queue.extend(
-                    edge_event["state_key"]
-                    for edge_event in room_entry.children_state_events
-                )
-
-        return {"rooms": rooms_result, "events": events_result}
-
     async def get_federation_hierarchy(
         self,
         origin: str,
@@ -579,7 +361,7 @@ class RoomSummaryHandler:
             The JSON hierarchy dictionary.
         """
         root_room_entry = await self._summarize_local_room(
-            None, origin, requested_room_id, suggested_only, max_children=None
+            None, origin, requested_room_id, suggested_only
         )
         if root_room_entry is None:
             # Room is inaccessible to the requesting server.
@@ -600,7 +382,7 @@ class RoomSummaryHandler:
                 continue
 
             room_entry = await self._summarize_local_room(
-                None, origin, room_id, suggested_only, max_children=0
+                None, origin, room_id, suggested_only, include_children=False
             )
             # If the room is accessible, include it in the results.
             #
@@ -626,7 +408,7 @@ class RoomSummaryHandler:
         origin: Optional[str],
         room_id: str,
         suggested_only: bool,
-        max_children: Optional[int],
+        include_children: bool = True,
     ) -> Optional["_RoomEntry"]:
         """
         Generate a room entry and a list of event entries for a given room.
@@ -641,9 +423,8 @@ class RoomSummaryHandler:
             room_id: The room ID to summarize.
             suggested_only: True if only suggested children should be returned.
                 Otherwise, all children are returned.
-            max_children:
-                The maximum number of children rooms to include. A value of None
-                means no limit.
+            include_children:
+                Whether to include the events of any children.
 
         Returns:
             A room entry if the room should be returned. None, otherwise.
@@ -653,9 +434,8 @@ class RoomSummaryHandler:
 
         room_entry = await self._build_room_entry(room_id, for_federation=bool(origin))
 
-        # If the room is not a space or the children don't matter, return just
-        # the room information.
-        if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0:
+        # If the room is not a space return just the room information.
+        if room_entry.get("room_type") != RoomTypes.SPACE or not include_children:
             return _RoomEntry(room_id, room_entry)
 
         # Otherwise, look for child rooms/spaces.
@@ -665,14 +445,6 @@ class RoomSummaryHandler:
             # we only care about suggested children
             child_events = filter(_is_suggested_child_event, child_events)
 
-        # TODO max_children is legacy code for the /spaces endpoint.
-        if max_children is not None:
-            child_iter: Iterable[EventBase] = itertools.islice(
-                child_events, max_children
-            )
-        else:
-            child_iter = child_events
-
         stripped_events: List[JsonDict] = [
             {
                 "type": e.type,
@@ -682,80 +454,10 @@ class RoomSummaryHandler:
                 "sender": e.sender,
                 "origin_server_ts": e.origin_server_ts,
             }
-            for e in child_iter
+            for e in child_events
         ]
         return _RoomEntry(room_id, room_entry, stripped_events)
 
-    async def _summarize_remote_room(
-        self,
-        room: "_RoomQueueEntry",
-        suggested_only: bool,
-        max_children: Optional[int],
-        exclude_rooms: Iterable[str],
-    ) -> Iterable["_RoomEntry"]:
-        """
-        Request room entries and a list of event entries for a given room by querying a remote server.
-
-        Args:
-            room: The room to summarize.
-            suggested_only: True if only suggested children should be returned.
-                Otherwise, all children are returned.
-            max_children:
-                The maximum number of children rooms to include. This is capped
-                to a server-set limit.
-            exclude_rooms:
-                Rooms IDs which do not need to be summarized.
-
-        Returns:
-            An iterable of room entries.
-        """
-        room_id = room.room_id
-        logger.info("Requesting summary for %s via %s", room_id, room.via)
-
-        # we need to make the exclusion list json-serialisable
-        exclude_rooms = list(exclude_rooms)
-
-        via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE)
-        try:
-            res = await self._federation_client.get_space_summary(
-                via,
-                room_id,
-                suggested_only=suggested_only,
-                max_rooms_per_space=max_children,
-                exclude_rooms=exclude_rooms,
-            )
-        except Exception as e:
-            logger.warning(
-                "Unable to get summary of %s via federation: %s",
-                room_id,
-                e,
-                exc_info=logger.isEnabledFor(logging.DEBUG),
-            )
-            return ()
-
-        # Group the events by their room.
-        children_by_room: Dict[str, List[JsonDict]] = {}
-        for ev in res.events:
-            if ev.event_type == EventTypes.SpaceChild:
-                children_by_room.setdefault(ev.room_id, []).append(ev.data)
-
-        # Generate the final results.
-        results = []
-        for fed_room in res.rooms:
-            fed_room_id = fed_room.get("room_id")
-            if not fed_room_id or not isinstance(fed_room_id, str):
-                continue
-
-            results.append(
-                _RoomEntry(
-                    fed_room_id,
-                    fed_room,
-                    children_by_room.get(fed_room_id, []),
-                )
-            )
-
-        return results
-
     async def _summarize_remote_room_hierarchy(
         self, room: "_RoomQueueEntry", suggested_only: bool
     ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]:
@@ -958,9 +660,8 @@ class RoomSummaryHandler:
         ):
             return True
 
-        # Check if the user is a member of any of the allowed spaces
-        # from the response.
-        allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces")
+        # Check if the user is a member of any of the allowed rooms from the response.
+        allowed_rooms = room.get("allowed_room_ids")
         if allowed_rooms and isinstance(allowed_rooms, list):
             if await self._event_auth_handler.is_user_in_rooms(
                 allowed_rooms, requester
@@ -1028,8 +729,6 @@ class RoomSummaryHandler:
                 )
                 if allowed_rooms:
                     entry["allowed_room_ids"] = allowed_rooms
-                    # TODO Remove this key once the API is stable.
-                    entry["allowed_spaces"] = allowed_rooms
 
         # Filter out Nones – rather omit the field altogether
         room_entry = {k: v for k, v in entry.items() if v is not None}
@@ -1094,7 +793,7 @@ class RoomSummaryHandler:
                 room_id,
                 # Suggested-only doesn't matter since no children are requested.
                 suggested_only=False,
-                max_children=0,
+                include_children=False,
             )
 
             if not room_entry:
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 727d75a50c..9602f0d0bb 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -52,7 +52,7 @@ class Saml2SessionData:
 
 class SamlHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config)
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 0e0e58de02..aa16e417eb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -49,7 +49,7 @@ class _SearchResult:
 
 class SearchHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state_handler = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.hs = hs
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 706ad72761..73861bbd40 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -27,7 +27,7 @@ class SetPasswordHandler:
     """Handler which deals with changing user account passwords"""
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 0bb8b0929e..ff5b5169ca 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -180,7 +180,7 @@ class SsoHandler:
 
     def __init__(self, hs: "HomeServer"):
         self._clock = hs.get_clock()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
         self._auth_handler = hs.get_auth_handler()
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index d30ba2b724..2d197282ed 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -30,7 +30,7 @@ class MatchChange(Enum):
 
 class StateDeltasHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def _get_key_change(
         self,
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 29e41a4c79..436cd971ce 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -39,7 +39,7 @@ class StatsHandler:
 
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state = hs.get_state_handler()
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e6050cbce6..0aa3052fd6 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -266,7 +266,7 @@ class SyncResult:
 class SyncHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs_config = hs.config
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.presence_handler = hs.get_presence_handler()
         self.event_sources = hs.get_event_sources()
@@ -697,6 +697,15 @@ class SyncHandler:
         else:
             # no events in this room - so presumably no state
             state = {}
+
+            # (erikj) This should be rarely hit, but we've had some reports that
+            # we get more state down gappy syncs than we should, so let's add
+            # some logging.
+            logger.info(
+                "Failed to find any events in room %s at %s",
+                room_id,
+                stream_position.room_key,
+            )
         return state
 
     async def compute_summary(
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e4bed1c937..843c68eb0f 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -57,7 +57,7 @@ class FollowerTypingHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.server_name = hs.config.server.server_name
         self.clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
@@ -446,7 +446,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
 class TypingNotificationEventSource(EventSource[int, JsonDict]):
     def __init__(self, hs: "HomeServer"):
-        self._main_store = hs.get_datastore()
+        self._main_store = hs.get_datastores().main
         self.clock = hs.get_clock()
         # We can't call get_typing_handler here because there's a cycle:
         #
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 184730ebe8..014754a630 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -139,7 +139,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
 class _BaseThreepidAuthChecker:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def _check_threepid(self, medium: str, authdict: dict) -> dict:
         if "threepid_creds" not in authdict:
@@ -255,7 +255,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
         super().__init__(hs)
         self.hs = hs
         self._enabled = bool(hs.config.registration.registration_requires_token)
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def is_enabled(self) -> bool:
         return self._enabled
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 1565e034cb..d27ed2be6a 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -55,7 +55,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index e7656fbb9f..40bf1e06d6 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -351,7 +351,7 @@ class MatrixFederationHttpClient:
         )
 
         self.clock = hs.get_clock()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self.version_string_bytes = hs.version_string.encode("ascii")
         self.default_timeout = 60
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 07020bfb8d..7e46931869 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -145,6 +145,7 @@ __all__ = [
     "JsonDict",
     "EventBase",
     "StateMap",
+    "ProfileInfo",
 ]
 
 logger = logging.getLogger(__name__)
@@ -172,7 +173,9 @@ class ModuleApi:
 
         # TODO: Fix this type hint once the types for the data stores have been ironed
         #       out.
-        self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore()
+        self._store: Union[
+            DataStore, "GenericWorkerSlavedStore"
+        ] = hs.get_datastores().main
         self._auth = hs.get_auth()
         self._auth_handler = auth_handler
         self._server_name = hs.hostname
@@ -926,7 +929,7 @@ class ModuleApi:
         )
 
         # Try to retrieve the resulting event.
-        event = await self._hs.get_datastore().get_event(event_id)
+        event = await self._hs.get_datastores().main.get_event(event_id)
 
         # update_membership is supposed to always return after the event has been
         # successfully persisted.
@@ -1270,7 +1273,7 @@ class PublicRoomListManager:
     """
 
     def __init__(self, hs: "HomeServer"):
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
     async def room_is_in_public_room_list(self, room_id: str) -> bool:
         """Checks whether a room is in the public room list.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 753dd6b6a5..16d15a1f33 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -222,7 +222,7 @@ class Notifier:
         self.hs = hs
         self.storage = hs.get_storage()
         self.event_sources = hs.get_event_sources()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.pending_new_room_events: List[_PendingRoomEventEntry] = []
 
         # Called when there are new things to stream over replication
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 5176a1c186..a1b7711098 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -68,7 +68,7 @@ class ThrottleParams:
 class Pusher(metaclass=abc.ABCMeta):
     def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
         self.hs = hs
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.clock = self.hs.get_clock()
 
         self.pusher_id = pusher_config.id
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index bee660893b..fecf86034e 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -103,7 +103,7 @@ class BulkPushRuleEvaluator:
 
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self._event_auth_handler = hs.get_event_auth_handler()
 
         # Used by `RulesForRoom` to ensure only one thing mutates the cache at a
@@ -366,7 +366,7 @@ class RulesForRoom:
         """
         self.room_id = room_id
         self.is_mine_id = hs.is_mine_id
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
 
         # Used to ensure only one thing mutates the cache at a time. Keyed off
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 39bb2acae4..1710dd51b9 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -66,7 +66,7 @@ class EmailPusher(Pusher):
         super().__init__(hs, pusher_config)
         self.mailer = mailer
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.email = pusher_config.pushkey
         self.timed_call: Optional[IDelayedCall] = None
         self.throttle_params: Dict[str, ThrottleParams] = {}
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index bf40f596f6..f3c4419932 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -138,7 +138,7 @@ class HttpPusher(Pusher):
         # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
         # to be largely redundant. perhaps we can remove it.
         badge = await push_tools.get_badge_count(
-            self.hs.get_datastore(),
+            self.hs.get_datastores().main,
             self.user_id,
             group_by_room=self._group_unread_count_by_room,
         )
@@ -288,7 +288,7 @@ class HttpPusher(Pusher):
 
         tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
         badge = await push_tools.get_badge_count(
-            self.hs.get_datastore(),
+            self.hs.get_datastores().main,
             self.user_id,
             group_by_room=self._group_unread_count_by_room,
         )
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 3df8452eec..649a4f49d0 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -112,7 +112,7 @@ class Mailer:
         self.template_text = template_text
 
         self.send_email_handler = hs.get_send_email_handler()
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.state_store = self.hs.get_storage().state
         self.macaroon_gen = self.hs.get_macaroon_generator()
         self.state_handler = self.hs.get_state_handler()
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 659a53805d..f617c759e6 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -15,12 +15,12 @@
 
 import logging
 import re
-from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union
 
 from matrix_common.regex import glob_to_regex, to_word_pattern
 
 from synapse.events import EventBase
-from synapse.types import JsonDict, UserID
+from synapse.types import UserID
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -223,7 +223,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
 
 
 def _flatten_dict(
-    d: Union[EventBase, JsonDict],
+    d: Union[EventBase, Mapping[str, Any]],
     prefix: Optional[List[str]] = None,
     result: Optional[Dict[str, str]] = None,
 ) -> Dict[str, str]:
@@ -234,7 +234,7 @@ def _flatten_dict(
     for key, value in d.items():
         if isinstance(value, str):
             result[".".join(prefix + [key])] = value.lower()
-        elif isinstance(value, dict):
+        elif isinstance(value, Mapping):
             _flatten_dict(value, prefix=(prefix + [key]), result=result)
 
     return result
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 7912311d24..d0cc657b44 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -59,7 +59,7 @@ class PusherPool:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.pusher_factory = PusherFactory(hs)
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.clock = self.hs.get_clock()
 
         # We shard the handling of push notifications by user ID.
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index f43fbb5842..8f48a33936 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -17,14 +17,7 @@
 
 import itertools
 import logging
-from typing import List, Set
-
-from pkg_resources import (
-    DistributionNotFound,
-    Requirement,
-    VersionConflict,
-    get_provider,
-)
+from typing import Set
 
 logger = logging.getLogger(__name__)
 
@@ -90,6 +83,8 @@ REQUIREMENTS = [
     # ijson 3.1.4 fixes a bug with "." in property names
     "ijson>=3.1.4",
     "matrix-common~=1.1.0",
+    # For runtime introspection of our dependencies
+    "packaging~=21.3",
 ]
 
 CONDITIONAL_REQUIREMENTS = {
@@ -144,102 +139,6 @@ def list_requirements():
     return list(set(REQUIREMENTS) | ALL_OPTIONAL_REQUIREMENTS)
 
 
-class DependencyException(Exception):
-    @property
-    def message(self):
-        return "\n".join(
-            [
-                "Missing Requirements: %s" % (", ".join(self.dependencies),),
-                "To install run:",
-                "    pip install --upgrade --force %s" % (" ".join(self.dependencies),),
-                "",
-            ]
-        )
-
-    @property
-    def dependencies(self):
-        for i in self.args[0]:
-            yield '"' + i + '"'
-
-
-def check_requirements(for_feature=None):
-    deps_needed = []
-    errors = []
-
-    if for_feature:
-        reqs = CONDITIONAL_REQUIREMENTS[for_feature]
-    else:
-        reqs = REQUIREMENTS
-
-    for dependency in reqs:
-        try:
-            _check_requirement(dependency)
-        except VersionConflict as e:
-            deps_needed.append(dependency)
-            errors.append(
-                "Needed %s, got %s==%s"
-                % (
-                    dependency,
-                    e.dist.project_name,  # type: ignore[attr-defined] # noqa
-                    e.dist.version,  # type: ignore[attr-defined] # noqa
-                )
-            )
-        except DistributionNotFound:
-            deps_needed.append(dependency)
-            if for_feature:
-                errors.append(
-                    "Needed %s for the '%s' feature but it was not installed"
-                    % (dependency, for_feature)
-                )
-            else:
-                errors.append("Needed %s but it was not installed" % (dependency,))
-
-    if not for_feature:
-        # Check the optional dependencies are up to date. We allow them to not be
-        # installed.
-        OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), [])
-
-        for dependency in OPTS:
-            try:
-                _check_requirement(dependency)
-            except VersionConflict as e:
-                deps_needed.append(dependency)
-                errors.append(
-                    "Needed optional %s, got %s==%s"
-                    % (
-                        dependency,
-                        e.dist.project_name,  # type: ignore[attr-defined] # noqa
-                        e.dist.version,  # type: ignore[attr-defined] # noqa
-                    )
-                )
-            except DistributionNotFound:
-                # If it's not found, we don't care
-                pass
-
-    if deps_needed:
-        for err in errors:
-            logging.error(err)
-
-        raise DependencyException(deps_needed)
-
-
-def _check_requirement(dependency_string):
-    """Parses a dependency string, and checks if the specified requirement is installed
-
-    Raises:
-        VersionConflict if the requirement is installed, but with the the wrong version
-        DistributionNotFound if nothing is found to provide the requirement
-    """
-    req = Requirement.parse(dependency_string)
-
-    # first check if the markers specify that this requirement needs installing
-    if req.marker is not None and not req.marker.evaluate():
-        # not required for this environment
-        return
-
-    get_provider(req)
-
-
 if __name__ == "__main__":
     import sys
 
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index bc1d28dd19..2e697c74a6 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -268,7 +268,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                     raise e.to_synapse_error()
                 except Exception as e:
                     _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
-                    raise SynapseError(502, "Failed to talk to main process") from e
+                    raise SynapseError(
+                        502, f"Failed to talk to {instance_name} process"
+                    ) from e
 
                 _outgoing_request_counter.labels(cls.NAME, 200).inc()
                 return result
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index f2f40129fe..3d63645726 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -63,7 +63,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
         super().__init__(hs)
 
         self.device_list_updater = hs.get_device_handler().device_list_updater
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
 
     @staticmethod
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index d529c8a19f..3e7300b4a1 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -68,7 +68,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.clock = hs.get_clock()
         self.federation_event_handler = hs.get_federation_event_handler()
@@ -167,7 +167,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.registry = hs.get_federation_registry()
 
@@ -214,7 +214,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.registry = hs.get_federation_registry()
 
@@ -260,7 +260,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     @staticmethod
     async def _serialize_payload(room_id: str) -> JsonDict:  # type: ignore[override]
@@ -297,7 +297,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     @staticmethod
     async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict:  # type: ignore[override]
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 0145858e47..663bff5738 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -50,7 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         super().__init__(hs)
 
         self.federation_handler = hs.get_federation_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
 
     @staticmethod
@@ -119,7 +119,7 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
         super().__init__(hs)
 
         self.federation_handler = hs.get_federation_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
 
     @staticmethod
@@ -188,7 +188,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.member_handler = hs.get_room_member_handler()
 
@@ -258,7 +258,7 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.member_handler = hs.get_room_member_handler()
 
@@ -325,7 +325,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
         super().__init__(hs)
 
         self.registeration_handler = hs.get_registration_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.distributor = hs.get_distributor()
 
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index c7f751b70d..6c8f8388fd 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -36,7 +36,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
@@ -112,7 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 33e98daf8a..ce78176836 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -69,7 +69,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
         super().__init__(hs)
 
         self.event_creation_handler = hs.get_event_creation_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.clock = hs.get_clock()
 
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index d59ce7ccf9..1b8479b0b4 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -111,7 +111,7 @@ class ReplicationDataHandler:
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self._reactor = hs.get_reactor()
         self._clock = hs.get_clock()
@@ -340,7 +340,7 @@ class FederationSenderHandler:
     def __init__(self, hs: "HomeServer"):
         assert hs.should_send_federation()
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self._is_mine_id = hs.is_mine_id
         self._hs = hs
 
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 17e1572393..0d2013a3cf 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -95,7 +95,7 @@ class ReplicationCommandHandler:
     def __init__(self, hs: "HomeServer"):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._notifier = hs.get_notifier()
         self._clock = hs.get_clock()
         self._instance_id = hs.get_instance_id()
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ecd6190f5b..494e42a2be 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -72,7 +72,7 @@ class ReplicationStreamer:
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 914b9eae84..23d631a769 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -239,7 +239,7 @@ class BackfillStream(Stream):
     ROW_TYPE = BackfillStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             self._current_token,
@@ -267,7 +267,7 @@ class PresenceStream(Stream):
     ROW_TYPE = PresenceStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
 
         if hs.get_instance_name() in hs.config.worker.writers.presence:
             # on the presence writer, query the presence handler
@@ -355,7 +355,7 @@ class ReceiptsStream(Stream):
     ROW_TYPE = ReceiptsStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_max_receipt_stream_id),
@@ -374,7 +374,7 @@ class PushRulesStream(Stream):
     ROW_TYPE = PushRulesStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         super().__init__(
             hs.get_instance_name(),
@@ -401,7 +401,7 @@ class PushersStream(Stream):
     ROW_TYPE = PushersStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
 
         super().__init__(
             hs.get_instance_name(),
@@ -434,7 +434,7 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             store.get_cache_stream_token_for_writer,
@@ -455,7 +455,7 @@ class DeviceListsStream(Stream):
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_device_stream_token),
@@ -474,7 +474,7 @@ class ToDeviceStream(Stream):
     ROW_TYPE = ToDeviceStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_to_device_stream_token),
@@ -495,7 +495,7 @@ class TagAccountDataStream(Stream):
     ROW_TYPE = TagAccountDataStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_max_account_data_stream_id),
@@ -516,7 +516,7 @@ class AccountDataStream(Stream):
     ROW_TYPE = AccountDataStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(self.store.get_max_account_data_stream_id),
@@ -585,7 +585,7 @@ class GroupServerStream(Stream):
     ROW_TYPE = GroupsStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_group_stream_token),
@@ -604,7 +604,7 @@ class UserSignatureStream(Stream):
     ROW_TYPE = UserSignatureStreamRow
 
     def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastore()
+        store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_device_stream_token),
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 50c4a5ba03..26f4fa7cfd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -124,7 +124,7 @@ class EventsStream(Stream):
     NAME = "events"
 
     def __init__(self, hs: "HomeServer"):
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         super().__init__(
             hs.get_instance_name(),
             self._store._stream_id_gen.get_current_token_for_writer,
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index ba0d989d81..6de302f813 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -116,7 +116,7 @@ class PurgeHistoryRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.pagination_handler = hs.get_pagination_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_POST(
diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py
index e9bce22a34..93a78db811 100644
--- a/synapse/rest/admin/background_updates.py
+++ b/synapse/rest/admin/background_updates.py
@@ -112,7 +112,7 @@ class BackgroundUpdateStartJobRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self._auth, request)
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index d9905ff560..cef46ba0dd 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -44,7 +44,7 @@ class DeviceRestServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.is_mine = hs.is_mine
 
     async def on_GET(
@@ -113,7 +113,7 @@ class DevicesRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.is_mine = hs.is_mine
 
     async def on_GET(
@@ -144,7 +144,7 @@ class DeleteDevicesRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.is_mine = hs.is_mine
 
     async def on_POST(
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 38477f8ead..6d634eef70 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -53,7 +53,7 @@ class EventReportsRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
@@ -115,7 +115,7 @@ class EventReportDetailRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, report_id: str
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index d162e0081e..023ed92144 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -48,7 +48,7 @@ class ListDestinationsRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self._auth, request)
@@ -105,7 +105,7 @@ class DestinationRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, destination: str
@@ -165,7 +165,7 @@ class DestinationMembershipRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, destination: str
@@ -221,7 +221,7 @@ class DestinationResetConnectionRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._authenticator = Authenticator(hs)
 
     async def on_POST(
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 299f5c9eb0..8ca57bdb28 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -47,7 +47,7 @@ class QuarantineMediaInRoom(RestServlet):
     ]
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_POST(
@@ -74,7 +74,7 @@ class QuarantineMediaByUser(RestServlet):
     PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_POST(
@@ -103,7 +103,7 @@ class QuarantineMediaByID(RestServlet):
     )
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_POST(
@@ -132,7 +132,7 @@ class UnquarantineMediaByID(RestServlet):
     )
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_POST(
@@ -156,7 +156,7 @@ class ProtectMediaByID(RestServlet):
     PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_POST(
@@ -178,7 +178,7 @@ class UnprotectMediaByID(RestServlet):
     PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_POST(
@@ -200,7 +200,7 @@ class ListMediaInRoom(RestServlet):
     PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_GET(
@@ -251,7 +251,7 @@ class DeleteMediaByID(RestServlet):
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
@@ -283,7 +283,7 @@ class DeleteMediaByDateSize(RestServlet):
     PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.server_name = hs.hostname
         self.media_repository = hs.get_media_repository()
@@ -352,7 +352,7 @@ class UserMediaRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.media_repository = hs.get_media_repository()
 
     async def on_GET(
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 04948b6408..af606e9252 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -71,7 +71,7 @@ class ListRegistrationTokensRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
@@ -109,7 +109,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         # A string of all the characters allowed to be in a registration_token
         self.allowed_chars = string.ascii_letters + string.digits + "._~-"
@@ -260,7 +260,7 @@ class RegistrationTokenRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
         """Retrieve a registration token."""
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 5b706efbcf..f4736a3dad 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -65,7 +65,7 @@ class RoomRestV2Servlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._pagination_handler = hs.get_pagination_handler()
 
     async def on_DELETE(
@@ -188,7 +188,7 @@ class ListRoomRestServlet(RestServlet):
     PATTERNS = admin_patterns("/rooms$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
@@ -278,7 +278,7 @@ class RoomRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
         self.pagination_handler = hs.get_pagination_handler()
 
@@ -382,7 +382,7 @@ class RoomMembersRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, room_id: str
@@ -408,7 +408,7 @@ class RoomStateRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self._event_serializer = hs.get_event_client_serializer()
 
@@ -525,7 +525,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.event_creation_handler = hs.get_event_creation_handler()
         self.state_handler = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id
@@ -670,7 +670,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_DELETE(
         self, request: SynapseRequest, room_identifier: str
@@ -781,7 +781,7 @@ class BlockRoomRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 7a6546372e..3b142b8402 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -38,7 +38,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index c2617ee30c..8e29ada8a0 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -66,7 +66,7 @@ class UsersRestServletV2(RestServlet):
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
@@ -156,7 +156,7 @@ class UserRestServletV2(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth_handler = hs.get_auth_handler()
         self.profile_handler = hs.get_profile_handler()
         self.set_password_handler = hs.get_set_password_handler()
@@ -588,7 +588,7 @@ class DeactivateAccountRestServlet(RestServlet):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
         self.auth = hs.get_auth()
         self.is_mine = hs.is_mine
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_POST(
         self, request: SynapseRequest, target_user_id: str
@@ -674,7 +674,7 @@ class ResetPasswordRestServlet(RestServlet):
     PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self._set_password_handler = hs.get_set_password_handler()
@@ -717,7 +717,7 @@ class SearchUsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.is_mine = hs.is_mine
 
@@ -775,7 +775,7 @@ class UserAdminServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.is_mine = hs.is_mine
 
@@ -835,7 +835,7 @@ class UserMembershipRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
@@ -864,7 +864,7 @@ class PushersRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_GET(
@@ -905,7 +905,7 @@ class UserTokenRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.is_mine_id = hs.is_mine_id
@@ -974,7 +974,7 @@ class ShadowBanRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.is_mine_id = hs.is_mine_id
 
@@ -1026,7 +1026,7 @@ class RateLimitRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$")
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.is_mine_id = hs.is_mine_id
 
@@ -1129,7 +1129,7 @@ class AccountDataRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._is_mine_id = hs.is_mine_id
 
     async def on_GET(
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 9e38df81b4..d7821cbfa5 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -60,7 +60,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
-        self.datastore = hs.get_datastore()
+        self.datastore = hs.get_datastores().main
         self.config = hs.config
         self.identity_handler = hs.get_identity_handler()
 
@@ -114,7 +114,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
         # canonicalisation, matches the one in our database.
-        existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+        existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
             "email", email
         )
 
@@ -168,7 +168,7 @@ class PasswordRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
-        self.datastore = self.hs.get_datastore()
+        self.datastore = self.hs.get_datastores().main
         self.password_policy_handler = hs.get_password_policy_handler()
         self._set_password_handler = hs.get_set_password_handler()
 
@@ -347,7 +347,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
         self.hs = hs
         self.config = hs.config
         self.identity_handler = hs.get_identity_handler()
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
             self.mailer = Mailer(
@@ -450,7 +450,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         super().__init__()
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.identity_handler = hs.get_identity_handler()
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
@@ -536,7 +536,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
         super().__init__()
         self.config = hs.config
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
             self._failure_email_template = (
                 self.config.email.email_add_threepid_template_failure_html
@@ -603,7 +603,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
         super().__init__()
         self.config = hs.config
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.identity_handler = hs.get_identity_handler()
 
     async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
@@ -637,7 +637,7 @@ class ThreepidRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
-        self.datastore = self.hs.get_datastore()
+        self.datastore = self.hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
@@ -771,7 +771,7 @@ class ThreepidUnbindRestServlet(RestServlet):
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
         self.auth = hs.get_auth()
-        self.datastore = self.hs.get_datastore()
+        self.datastore = self.hs.get_datastores().main
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """Unbind the given 3pid from a specific identity server, or identity servers that are
@@ -899,6 +899,33 @@ class WhoamiRestServlet(RestServlet):
         return 200, response
 
 
+class AccountStatusRestServlet(RestServlet):
+    PATTERNS = client_patterns(
+        "/org.matrix.msc3720/account_status$", unstable=True, releases=()
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._auth = hs.get_auth()
+        self._account_handler = hs.get_account_handler()
+
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self._auth.get_user_by_req(request)
+
+        body = parse_json_object_from_request(request)
+        if "user_ids" not in body:
+            raise SynapseError(
+                400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
+            )
+
+        statuses, failures = await self._account_handler.get_account_statuses(
+            body["user_ids"],
+            allow_remote=True,
+        )
+
+        return 200, {"account_statuses": statuses, "failures": failures}
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     EmailPasswordRequestTokenRestServlet(hs).register(http_server)
     PasswordRestServlet(hs).register(http_server)
@@ -913,3 +940,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ThreepidUnbindRestServlet(hs).register(http_server)
     ThreepidDeleteRestServlet(hs).register(http_server)
     WhoamiRestServlet(hs).register(http_server)
+
+    if hs.config.experimental.msc3720_enabled:
+        AccountStatusRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index 58b8adbd32..bfe985939b 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -42,7 +42,7 @@ class AccountDataServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.handler = hs.get_account_data_handler()
 
     async def on_PUT(
@@ -90,7 +90,7 @@ class RoomAccountDataServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.handler = hs.get_account_data_handler()
 
     async def on_PUT(
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index e05c926b6f..4237071c61 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -72,8 +72,10 @@ class CapabilitiesRestServlet(RestServlet):
                 "org.matrix.msc3244.room_capabilities"
             ] = MSC3244_CAPABILITIES
 
-        if self.config.experimental.msc3440_enabled:
-            response["capabilities"]["io.element.thread"] = {"enabled": True}
+        if self.config.experimental.msc3720_enabled:
+            response["capabilities"]["org.matrix.msc3720.account_status"] = {
+                "enabled": True,
+            }
 
         return HTTPStatus.OK, response
 
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index ee247e3d1e..e181a0dde2 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -47,7 +47,7 @@ class ClientDirectoryServer(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.directory_handler = hs.get_directory_handler()
         self.auth = hs.get_auth()
 
@@ -129,7 +129,7 @@ class ClientDirectoryListServer(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.directory_handler = hs.get_directory_handler()
         self.auth = hs.get_auth()
 
@@ -173,7 +173,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.directory_handler = hs.get_directory_handler()
         self.auth = hs.get_auth()
 
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 672c821061..916f5230f1 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -39,7 +39,7 @@ class EventStreamRestServlet(RestServlet):
         super().__init__()
         self.event_stream_handler = hs.get_event_stream_handler()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
index a7e9aa3e9b..7e1149c7f4 100644
--- a/synapse/rest/client/groups.py
+++ b/synapse/rest/client/groups.py
@@ -705,7 +705,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.is_mine_id = hs.is_mine_id
 
     @_validate_group_id
@@ -854,7 +854,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     @_validate_group_id
     async def on_PUT(
@@ -879,7 +879,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.groups_handler = hs.get_groups_local_handler()
 
     async def on_GET(
@@ -901,7 +901,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.groups_handler = hs.get_groups_local_handler()
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py
index 49b1037b28..cfadcb8e50 100644
--- a/synapse/rest/client/initial_sync.py
+++ b/synapse/rest/client/initial_sync.py
@@ -33,7 +33,7 @@ class InitialSyncRestServlet(RestServlet):
         super().__init__()
         self.initial_sync_handler = hs.get_initial_sync_handler()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 730c18f08f..ce806e3c11 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -198,7 +198,7 @@ class KeyChangesServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index f9994658c4..c9d44c5964 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -104,13 +104,13 @@ class LoginRestServlet(RestServlet):
 
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
-            store=hs.get_datastore(),
+            store=hs.get_datastores().main,
             clock=hs.get_clock(),
             rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second,
             burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count,
         )
         self._account_ratelimiter = Ratelimiter(
-            store=hs.get_datastore(),
+            store=hs.get_datastores().main,
             clock=hs.get_clock(),
             rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second,
             burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 8e427a96a3..20377a9ac6 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -35,7 +35,7 @@ class NotificationsServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self._event_serializer = hs.get_event_client_serializer()
diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py
index add56d6998..820682ec42 100644
--- a/synapse/rest/client/openid.py
+++ b/synapse/rest/client/openid.py
@@ -67,7 +67,7 @@ class IdTokenServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.server_name = hs.config.server.server_name
 
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 8fe75bd750..a93f6fd5e0 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -57,7 +57,7 @@ class PushRuleRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self._is_worker = hs.config.worker.worker_app is not None
 
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 98604a9388..d6487c31dd 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -46,7 +46,9 @@ class PushersRestServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         user = requester.user
 
-        pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
+        pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(
+            user.to_string()
+        )
 
         filtered_pushers = [p.as_dict() for p in pushers]
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index b8a5135e02..70baf50fa4 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -123,7 +123,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
             request, "email", email
         )
 
-        existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+        existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
             "email", email
         )
 
@@ -203,7 +203,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
             request, "msisdn", msisdn
         )
 
-        existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+        existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
             "msisdn", msisdn
         )
 
@@ -258,7 +258,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
         self.auth = hs.get_auth()
         self.config = hs.config
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
             self._failure_email_template = (
@@ -385,7 +385,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.ratelimiter = Ratelimiter(
             store=self.store,
             clock=hs.get_clock(),
@@ -415,7 +415,7 @@ class RegisterRestServlet(RestServlet):
 
         self.hs = hs
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self.identity_handler = hs.get_identity_handler()
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 2cab83c4e6..487ea38b55 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -85,7 +85,7 @@ class RelationPaginationServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self._event_serializer = hs.get_event_client_serializer()
         self.event_handler = hs.get_event_handler()
@@ -190,7 +190,7 @@ class RelationAggregationPaginationServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.event_handler = hs.get_event_handler()
 
     async def on_GET(
@@ -282,7 +282,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self._event_serializer = hs.get_event_client_serializer()
         self.event_handler = hs.get_event_handler()
diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py
index d4a4adb50c..6e962a4532 100644
--- a/synapse/rest/client/report_event.py
+++ b/synapse/rest/client/report_event.py
@@ -38,7 +38,7 @@ class ReportEventRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_POST(
         self, request: SynapseRequest, room_id: str, event_id: str
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 90355e44b2..8a06ab8c5f 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -477,7 +477,7 @@ class RoomMemberListRestServlet(RestServlet):
         super().__init__()
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, room_id: str
@@ -553,7 +553,7 @@ class RoomMessageListRestServlet(RestServlet):
         self._hs = hs
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, room_id: str
@@ -621,7 +621,7 @@ class RoomInitialSyncRestServlet(RestServlet):
         super().__init__()
         self.initial_sync_handler = hs.get_initial_sync_handler()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, room_id: str
@@ -642,7 +642,7 @@ class RoomEventServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.clock = hs.get_clock()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self.event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
         self.auth = hs.get_auth()
@@ -1027,7 +1027,7 @@ class JoinedRoomsRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
@@ -1116,7 +1116,7 @@ class TimestampLookupRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self._auth = hs.get_auth()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self.timestamp_lookup_handler = hs.get_timestamp_lookup_handler()
 
     async def on_GET(
@@ -1141,73 +1141,6 @@ class TimestampLookupRestServlet(RestServlet):
         }
 
 
-class RoomSpaceSummaryRestServlet(RestServlet):
-    PATTERNS = (
-        re.compile(
-            "^/_matrix/client/unstable/org.matrix.msc2946"
-            "/rooms/(?P<room_id>[^/]*)/spaces$"
-        ),
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self._auth = hs.get_auth()
-        self._room_summary_handler = hs.get_room_summary_handler()
-
-    async def on_GET(
-        self, request: SynapseRequest, room_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request, allow_guest=True)
-
-        max_rooms_per_space = parse_integer(request, "max_rooms_per_space")
-        if max_rooms_per_space is not None and max_rooms_per_space < 0:
-            raise SynapseError(
-                400,
-                "Value for 'max_rooms_per_space' must be a non-negative integer",
-                Codes.BAD_JSON,
-            )
-
-        return 200, await self._room_summary_handler.get_space_summary(
-            requester.user.to_string(),
-            room_id,
-            suggested_only=parse_boolean(request, "suggested_only", default=False),
-            max_rooms_per_space=max_rooms_per_space,
-        )
-
-    # TODO When switching to the stable endpoint, remove the POST handler.
-    async def on_POST(
-        self, request: SynapseRequest, room_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request, allow_guest=True)
-        content = parse_json_object_from_request(request)
-
-        suggested_only = content.get("suggested_only", False)
-        if not isinstance(suggested_only, bool):
-            raise SynapseError(
-                400, "'suggested_only' must be a boolean", Codes.BAD_JSON
-            )
-
-        max_rooms_per_space = content.get("max_rooms_per_space")
-        if max_rooms_per_space is not None:
-            if not isinstance(max_rooms_per_space, int):
-                raise SynapseError(
-                    400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON
-                )
-            if max_rooms_per_space < 0:
-                raise SynapseError(
-                    400,
-                    "Value for 'max_rooms_per_space' must be a non-negative integer",
-                    Codes.BAD_JSON,
-                )
-
-        return 200, await self._room_summary_handler.get_space_summary(
-            requester.user.to_string(),
-            room_id,
-            suggested_only=suggested_only,
-            max_rooms_per_space=max_rooms_per_space,
-        )
-
-
 class RoomHierarchyRestServlet(RestServlet):
     PATTERNS = (
         re.compile(
@@ -1301,7 +1234,6 @@ def register_servlets(
     RoomRedactEventRestServlet(hs).register(http_server)
     RoomTypingRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
-    RoomSpaceSummaryRestServlet(hs).register(http_server)
     RoomHierarchyRestServlet(hs).register(http_server)
     if hs.config.experimental.msc3266_enabled:
         RoomSummaryRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 4b6be38327..0048973e59 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -75,7 +75,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.event_creation_handler = hs.get_event_creation_handler()
         self.auth = hs.get_auth()
         self.room_batch_handler = hs.get_room_batch_handler()
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py
index 09a46737de..e669fa7890 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/shared_rooms.py
@@ -41,7 +41,7 @@ class UserSharedRoomsServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.user_directory_active = hs.config.server.update_user_directory
 
     async def on_GET(
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index f9615da525..f3018ff690 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -103,7 +103,7 @@ class SyncRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.sync_handler = hs.get_sync_handler()
         self.clock = hs.get_clock()
         self.filtering = hs.get_filtering()
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index c88cb9367c..ca638755c7 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -39,7 +39,7 @@ class TagListServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str, room_id: str
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 00f29344a8..2e5d0e4e22 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -99,6 +99,8 @@ class VersionsRestServlet(RestServlet):
                     "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
                     # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
                     "org.matrix.msc3030": self.config.experimental.msc3030_enabled,
+                    # Adds support for thread relations, per MSC3440.
+                    "org.matrix.msc3440": self.config.experimental.msc3440_enabled,
                 },
             },
         )
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 3d2afacc50..25f9ea285b 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -78,7 +78,7 @@ class ConsentResource(DirectServeHtmlResource):
         super().__init__()
 
         self.hs = hs
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.registration_handler = hs.get_registration_handler()
 
         # this is required by the request_handler wrapper
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 3923ba8439..3525d6ae54 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -94,7 +94,7 @@ class RemoteKey(DirectServeJsonResource):
         super().__init__()
 
         self.fetcher = ServerKeyFetcher(hs)
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.federation_domain_whitelist = (
             hs.config.federation.federation_domain_whitelist
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 71b9a34b14..6c414402bd 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -75,7 +75,7 @@ class MediaRepository:
         self.client = hs.get_federation_http_client()
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.max_upload_size = hs.config.media.max_upload_size
         self.max_image_pixels = hs.config.media.max_image_pixels
 
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index c08b60d10a..14ea88b240 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -134,7 +134,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         self.filepaths = media_repo.filepaths
         self.max_spider_size = hs.config.media.max_spider_size
         self.server_name = hs.hostname
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.client = SimpleHttpClient(
             hs,
             treq_args={"browser_like_redirects": True},
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index ed91ef5a42..53b1565243 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -50,7 +50,7 @@ class ThumbnailResource(DirectServeJsonResource):
     ):
         super().__init__()
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.media_repo = media_repo
         self.media_storage = media_storage
         self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index fde28d08cb..e73e431dc9 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -37,7 +37,7 @@ class UploadResource(DirectServeJsonResource):
 
         self.media_repo = media_repo
         self.filepaths = media_repo.filepaths
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self.auth = hs.get_auth()
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index 28a67f04e3..6ac9dbc7c9 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -44,7 +44,7 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
         super().__init__()
 
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self._local_threepid_handling_disabled_due_to_email_config = (
             hs.config.email.local_threepid_handling_disabled_due_to_email_config
diff --git a/synapse/server.py b/synapse/server.py
index 564afdcb96..b5e2a319bc 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -17,7 +17,7 @@
 # homeservers; either as a full homeserver as a real application, or a small
 # partial one for unit test mocking.
 
-# Imports required for the default HomeServer() implementation
+
 import abc
 import functools
 import logging
@@ -62,6 +62,7 @@ from synapse.federation.sender import AbstractFederationSender, FederationSender
 from synapse.federation.transport.client import TransportLayerClient
 from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
 from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
+from synapse.handlers.account import AccountHandler
 from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.admin import AdminHandler
@@ -133,7 +134,7 @@ from synapse.server_notices.worker_server_notices_sender import (
     WorkerServerNoticesSender,
 )
 from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import Databases, DataStore, Storage
+from synapse.storage import Databases, Storage
 from synapse.streams.events import EventSources
 from synapse.types import DomainSpecificString, ISynapseReactor
 from synapse.util import Clock
@@ -224,7 +225,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     # This is overridden in derived application classes
     # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
-    # instantiated during setup() for future return by get_datastore()
+    # instantiated during setup() for future return by get_datastores()
     DATASTORE_CLASS = abc.abstractproperty()
 
     tls_server_context_factory: Optional[IOpenSSLContextFactory]
@@ -354,12 +355,6 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_clock(self) -> Clock:
         return Clock(self._reactor)
 
-    def get_datastore(self) -> DataStore:
-        if not self.datastores:
-            raise Exception("HomeServer.setup must be called before getting datastores")
-
-        return self.datastores.main
-
     def get_datastores(self) -> Databases:
         if not self.datastores:
             raise Exception("HomeServer.setup must be called before getting datastores")
@@ -373,7 +368,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     @cache_in_self
     def get_registration_ratelimiter(self) -> Ratelimiter:
         return Ratelimiter(
-            store=self.get_datastore(),
+            store=self.get_datastores().main,
             clock=self.get_clock(),
             rate_hz=self.config.ratelimiting.rc_registration.per_second,
             burst_count=self.config.ratelimiting.rc_registration.burst_count,
@@ -808,6 +803,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return ExternalCache(self)
 
     @cache_in_self
+    def get_account_handler(self) -> AccountHandler:
+        return AccountHandler(self)
+
+    @cache_in_self
     def get_outbound_redis_connection(self) -> "RedisProtocol":
         """
         The Redis connection used for replication.
@@ -842,7 +841,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     @cache_in_self
     def get_request_ratelimiter(self) -> RequestRatelimiter:
         return RequestRatelimiter(
-            self.get_datastore(),
+            self.get_datastores().main,
             self.get_clock(),
             self.config.ratelimiting.rc_message,
             self.config.ratelimiting.rc_admin_redaction,
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index e09a25591f..698ca742ed 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -32,7 +32,7 @@ class ConsentServerNotices:
 
     def __init__(self, hs: "HomeServer"):
         self._server_notices_manager = hs.get_server_notices_manager()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
         self._users_in_progress: Set[str] = set()
 
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 8522930b50..015dd08f05 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -36,7 +36,7 @@ class ResourceLimitsServerNotices:
 
     def __init__(self, hs: "HomeServer"):
         self._server_notices_manager = hs.get_server_notices_manager()
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._auth = hs.get_auth()
         self._config = hs.config
         self._resouce_limited = False
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 0cf60236f8..7b4814e049 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -29,7 +29,7 @@ SERVER_NOTICE_ROOM_TAG = "m.server_notice"
 
 class ServerNoticesManager:
     def __init__(self, hs: "HomeServer"):
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
         self._config = hs.config
         self._account_data_handler = hs.get_account_data_handler()
         self._room_creation_handler = hs.get_room_creation_handler()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 67e8bc6ec2..6babd5963c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -126,7 +126,7 @@ class StateHandler:
 
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state_store = hs.get_storage().state
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -258,7 +258,10 @@ class StateHandler:
         return await self.store.get_joined_hosts(room_id, entry)
 
     async def compute_event_context(
-        self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
+        self,
+        event: EventBase,
+        old_state: Optional[Iterable[EventBase]] = None,
+        partial_state: bool = False,
     ) -> EventContext:
         """Build an EventContext structure for a non-outlier event.
 
@@ -273,6 +276,8 @@ class StateHandler:
                 calculated from existing events. This is normally only specified
                 when receiving an event from federation where we don't have the
                 prev events for, e.g. when backfilling.
+            partial_state: True if `old_state` is partial and omits non-critical
+                membership events
         Returns:
             The event context.
         """
@@ -295,8 +300,28 @@ class StateHandler:
 
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
-            logger.debug("calling resolve_state_groups from compute_event_context")
 
+            # partial_state should not be set explicitly in this case:
+            # we work it out dynamically
+            assert not partial_state
+
+            # if any of the prev-events have partial state, so do we.
+            # (This is slightly racy - the prev-events might get fixed up before we use
+            # their states - but I don't think that really matters; it just means we
+            # might redundantly recalculate the state for this event later.)
+            prev_event_ids = event.prev_event_ids()
+            incomplete_prev_events = await self.store.get_partial_state_events(
+                prev_event_ids
+            )
+            if any(incomplete_prev_events.values()):
+                logger.debug(
+                    "New/incoming event %s refers to prev_events %s with partial state",
+                    event.event_id,
+                    [k for (k, v) in incomplete_prev_events.items() if v],
+                )
+                partial_state = True
+
+            logger.debug("calling resolve_state_groups from compute_event_context")
             entry = await self.resolve_state_groups_for_events(
                 event.room_id, event.prev_event_ids()
             )
@@ -342,6 +367,7 @@ class StateHandler:
                 prev_state_ids=state_ids_before_event,
                 prev_group=state_group_before_event_prev_group,
                 delta_ids=deltas_to_state_group_before_event,
+                partial_state=partial_state,
             )
 
         #
@@ -373,6 +399,7 @@ class StateHandler:
             prev_state_ids=state_ids_before_event,
             prev_group=state_group_before_event,
             delta_ids=delta_ids,
+            partial_state=partial_state,
         )
 
     @measure_func()
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index cfe887b7f7..ce3d1d4e94 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -24,6 +24,7 @@ from synapse.storage.prepare_database import prepare_database
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
@@ -44,7 +45,7 @@ class Databases(Generic[DataStoreT]):
     """
 
     databases: List[DatabasePool]
-    main: DataStoreT
+    main: "DataStore"  # FIXME: #11165: actually an instance of `main_store_class`
     state: StateGroupDataStore
     persist_events: Optional[PersistEventsStore]
 
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 304814af5d..0694446558 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -20,14 +20,18 @@ from synapse.appservice import (
     ApplicationService,
     ApplicationServiceState,
     AppServiceTransaction,
+    TransactionOneTimeKeyCounts,
+    TransactionUnusedFallbackKeys,
 )
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import db_to_json
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.types import JsonDict
 from synapse.util import json_encoder
+from synapse.util.caches.descriptors import _CacheContext, cached
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -56,7 +60,7 @@ def _make_exclusive_regex(
     return exclusive_user_pattern
 
 
-class ApplicationServiceWorkerStore(SQLBaseStore):
+class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -124,6 +128,18 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
                 return service
         return None
 
+    @cached(iterable=True, cache_context=True)
+    async def get_app_service_users_in_room(
+        self,
+        room_id: str,
+        app_service: "ApplicationService",
+        cache_context: _CacheContext,
+    ) -> List[str]:
+        users_in_room = await self.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
+        return list(filter(app_service.is_interested_in_user, users_in_room))
+
 
 class ApplicationServiceStore(ApplicationServiceWorkerStore):
     # This is currently empty due to there not being any AS storage functions
@@ -199,6 +215,8 @@ class ApplicationServiceTransactionWorkerStore(
         events: List[EventBase],
         ephemeral: List[JsonDict],
         to_device_messages: List[JsonDict],
+        one_time_key_counts: TransactionOneTimeKeyCounts,
+        unused_fallback_keys: TransactionUnusedFallbackKeys,
     ) -> AppServiceTransaction:
         """Atomically creates a new transaction for this application service
         with the given list of events. Ephemeral events are NOT persisted to the
@@ -209,6 +227,10 @@ class ApplicationServiceTransactionWorkerStore(
             events: A list of persistent events to put in the transaction.
             ephemeral: A list of ephemeral events to put in the transaction.
             to_device_messages: A list of to-device messages to put in the transaction.
+            one_time_key_counts: Counts of remaining one-time keys for relevant
+                appservice devices in the transaction.
+            unused_fallback_keys: Lists of unused fallback keys for relevant
+                appservice devices in the transaction.
 
         Returns:
             A new transaction.
@@ -244,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore(
                 events=events,
                 ephemeral=ephemeral,
                 to_device_messages=to_device_messages,
+                one_time_key_counts=one_time_key_counts,
+                unused_fallback_keys=unused_fallback_keys,
             )
 
         return await self.db_pool.runInteraction(
@@ -335,12 +359,17 @@ class ApplicationServiceTransactionWorkerStore(
 
         events = await self.get_events_as_list(event_ids)
 
+        # TODO: to-device messages, one-time key counts and unused fallback keys
+        #       are not yet populated for catch-up transactions.
+        #       We likely want to populate those for reliability.
         return AppServiceTransaction(
             service=service,
             id=entry["txn_id"],
             events=events,
             ephemeral=[],
             to_device_messages=[],
+            one_time_key_counts={},
+            unused_fallback_keys={},
         )
 
     def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1f8447b507..9b293475c8 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -29,6 +29,10 @@ import attr
 from canonicaljson import encode_canonical_json
 
 from synapse.api.constants import DeviceKeyAlgorithms
+from synapse.appservice import (
+    TransactionOneTimeKeyCounts,
+    TransactionUnusedFallbackKeys,
+)
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -439,6 +443,114 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
+    async def count_bulk_e2e_one_time_keys_for_as(
+        self, user_ids: Collection[str]
+    ) -> TransactionOneTimeKeyCounts:
+        """
+        Counts, in bulk, the one-time keys for all the users specified.
+        Intended to be used by application services for populating OTK counts in
+        transactions.
+
+        Return structure is of the shape:
+          user_id -> device_id -> algorithm -> count
+          Empty algorithm -> count dicts are created if needed to represent a
+          lack of unused one-time keys.
+        """
+
+        def _count_bulk_e2e_one_time_keys_txn(
+            txn: LoggingTransaction,
+        ) -> TransactionOneTimeKeyCounts:
+            user_in_where_clause, user_parameters = make_in_list_sql_clause(
+                self.database_engine, "user_id", user_ids
+            )
+            sql = f"""
+                SELECT user_id, device_id, algorithm, COUNT(key_id)
+                FROM devices
+                LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id)
+                WHERE {user_in_where_clause}
+                GROUP BY user_id, device_id, algorithm
+            """
+            txn.execute(sql, user_parameters)
+
+            result: TransactionOneTimeKeyCounts = {}
+
+            for user_id, device_id, algorithm, count in txn:
+                # We deliberately construct empty dictionaries for
+                # users and devices without any unused one-time keys.
+                # We *could* omit these empty dicts if there have been no
+                # changes since the last transaction, but we currently don't
+                # do any change tracking!
+                device_count_by_algo = result.setdefault(user_id, {}).setdefault(
+                    device_id, {}
+                )
+                if algorithm is not None:
+                    # algorithm will be None if this device has no keys.
+                    device_count_by_algo[algorithm] = count
+
+            return result
+
+        return await self.db_pool.runInteraction(
+            "count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn
+        )
+
+    async def get_e2e_bulk_unused_fallback_key_types(
+        self, user_ids: Collection[str]
+    ) -> TransactionUnusedFallbackKeys:
+        """
+        Finds, in bulk, the types of unused fallback keys for all the users specified.
+        Intended to be used by application services for populating unused fallback
+        keys in transactions.
+
+        Return structure is of the shape:
+          user_id -> device_id -> algorithms
+          Empty lists are created for devices if there are no unused fallback
+          keys. This matches the response structure of MSC3202.
+        """
+        if len(user_ids) == 0:
+            return {}
+
+        def _get_bulk_e2e_unused_fallback_keys_txn(
+            txn: LoggingTransaction,
+        ) -> TransactionUnusedFallbackKeys:
+            user_in_where_clause, user_parameters = make_in_list_sql_clause(
+                self.database_engine, "devices.user_id", user_ids
+            )
+            # We can't use USING here because we require the `.used` condition
+            # to be part of the JOIN condition so that we generate empty lists
+            # when all keys are used (as opposed to just when there are no keys at all).
+            sql = f"""
+                SELECT devices.user_id, devices.device_id, algorithm
+                FROM devices
+                LEFT JOIN e2e_fallback_keys_json AS fallback_keys
+                    ON devices.user_id = fallback_keys.user_id
+                    AND devices.device_id = fallback_keys.device_id
+                    AND NOT fallback_keys.used
+                WHERE
+                    {user_in_where_clause}
+            """
+            txn.execute(sql, user_parameters)
+
+            result: TransactionUnusedFallbackKeys = {}
+
+            for user_id, device_id, algorithm in txn:
+                # We deliberately construct empty dictionaries and lists for
+                # users and devices without any unused fallback keys.
+                # We *could* omit these empty dicts if there have been no
+                # changes since the last transaction, but we currently don't
+                # do any change tracking!
+                device_unused_keys = result.setdefault(user_id, {}).setdefault(
+                    device_id, []
+                )
+                if algorithm is not None:
+                    # algorithm will be None if this device has no keys.
+                    device_unused_keys.append(algorithm)
+
+            return result
+
+        return await self.db_pool.runInteraction(
+            "_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn
+        )
+
     async def set_e2e_fallback_keys(
         self, user_id: str, device_id: str, fallback_keys: JsonDict
     ) -> None:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index a1d7a9b413..ca2a9ba9d1 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -130,7 +130,7 @@ class PersistEventsStore:
         *,
         current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremeties: Dict[str, List[str]],
+        new_forward_extremities: Dict[str, Set[str]],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
@@ -143,7 +143,7 @@ class PersistEventsStore:
                 the room based on forward extremities
             state_delta_for_room: Map from room_id to the delta to apply to
                 room state
-            new_forward_extremities: Map from room_id to list of event IDs
+            new_forward_extremities: Map from room_id to set of event IDs
                 that are the new forward extremities of the room.
             use_negative_stream_ordering: Whether to start stream_ordering on
                 the negative side and decrement. This should be set as True
@@ -193,7 +193,7 @@ class PersistEventsStore:
                 events_and_contexts=events_and_contexts,
                 inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
-                new_forward_extremeties=new_forward_extremeties,
+                new_forward_extremities=new_forward_extremities,
             )
             persist_event_counter.inc(len(events_and_contexts))
 
@@ -220,7 +220,7 @@ class PersistEventsStore:
             for room_id, new_state in current_state_for_room.items():
                 self.store.get_current_state_ids.prefill((room_id,), new_state)
 
-            for room_id, latest_event_ids in new_forward_extremeties.items():
+            for room_id, latest_event_ids in new_forward_extremities.items():
                 self.store.get_latest_event_ids_in_room.prefill(
                     (room_id,), list(latest_event_ids)
                 )
@@ -334,8 +334,8 @@ class PersistEventsStore:
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
         state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
-        new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
-    ):
+        new_forward_extremities: Optional[Dict[str, Set[str]]] = None,
+    ) -> None:
         """Insert some number of room events into the necessary database tables.
 
         Rejected events are only inserted into the events table, the events_json table,
@@ -353,13 +353,13 @@ class PersistEventsStore:
                 from the database. This is useful when retrying due to
                 IntegrityError.
             state_delta_for_room: The current-state delta for each room.
-            new_forward_extremetie: The new forward extremities for each room.
+            new_forward_extremities: The new forward extremities for each room.
                 For each room, a list of the event ids which are the forward
                 extremities.
 
         """
         state_delta_for_room = state_delta_for_room or {}
-        new_forward_extremeties = new_forward_extremeties or {}
+        new_forward_extremities = new_forward_extremities or {}
 
         all_events_and_contexts = events_and_contexts
 
@@ -372,7 +372,7 @@ class PersistEventsStore:
 
         self._update_forward_extremities_txn(
             txn,
-            new_forward_extremities=new_forward_extremeties,
+            new_forward_extremities=new_forward_extremities,
             max_stream_order=max_stream_order,
         )
 
@@ -1158,7 +1158,10 @@ class PersistEventsStore:
             )
 
     def _update_forward_extremities_txn(
-        self, txn, new_forward_extremities, max_stream_order
+        self,
+        txn: LoggingTransaction,
+        new_forward_extremities: Dict[str, Set[str]],
+        max_stream_order: int,
     ):
         for room_id in new_forward_extremities.keys():
             self.db_pool.simple_delete_txn(
@@ -1473,10 +1476,10 @@ class PersistEventsStore:
 
     def _update_metadata_tables_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         *,
-        events_and_contexts,
-        all_events_and_contexts,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        all_events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
     ):
         """Update all the miscellaneous tables for new events
@@ -1953,20 +1956,20 @@ class PersistEventsStore:
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
 
-    def _store_room_topic_txn(self, txn, event):
-        if hasattr(event, "content") and "topic" in event.content:
+    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+        if isinstance(event.content.get("topic"), str):
             self.store_event_search_txn(
                 txn, event, "content.topic", event.content["topic"]
             )
 
-    def _store_room_name_txn(self, txn, event):
-        if hasattr(event, "content") and "name" in event.content:
+    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+        if isinstance(event.content.get("name"), str):
             self.store_event_search_txn(
                 txn, event, "content.name", event.content["name"]
             )
 
-    def _store_room_message_txn(self, txn, event):
-        if hasattr(event, "content") and "body" in event.content:
+    def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+        if isinstance(event.content.get("body"), str):
             self.store_event_search_txn(
                 txn, event, "content.body", event.content["body"]
             )
@@ -2142,6 +2145,14 @@ class PersistEventsStore:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
+                # double-check that we don't have any events that claim to be outliers
+                # *and* have partial state (which is meaningless: we should have no
+                # state at all for an outlier)
+                if context.partial_state:
+                    raise ValueError(
+                        "Outlier event %s claims to have partial state", event.event_id
+                    )
+
                 continue
 
             # if the event was rejected, just give it the same state as its
@@ -2152,6 +2163,23 @@ class PersistEventsStore:
 
             state_groups[event.event_id] = context.state_group
 
+        # if we have partial state for these events, record the fact. (This happens
+        # here rather than in _store_event_txn because it also needs to happen when
+        # we de-outlier an event.)
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="partial_state_events",
+            keys=("room_id", "event_id"),
+            values=[
+                (
+                    event.room_id,
+                    event.event_id,
+                )
+                for event, ctx in events_and_contexts
+                if ctx.partial_state
+            ],
+        )
+
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_to_state_groups",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 2a255d1031..26784f755e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1953,3 +1953,31 @@ class EventsWorkerStore(SQLBaseStore):
             "get_event_id_for_timestamp_txn",
             get_event_id_for_timestamp_txn,
         )
+
+    @cachedList("is_partial_state_event", list_name="event_ids")
+    async def get_partial_state_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, bool]:
+        """Checks which of the given events have partial state"""
+        result = await self.db_pool.simple_select_many_batch(
+            table="partial_state_events",
+            column="event_id",
+            iterable=event_ids,
+            retcols=["event_id"],
+            desc="get_partial_state_events",
+        )
+        # convert the result to a dict, to make @cachedList work
+        partial = {r["event_id"] for r in result}
+        return {e_id: e_id in partial for e_id in event_ids}
+
+    @cached()
+    async def is_partial_state_event(self, event_id: str) -> bool:
+        """Checks if the given event has partial state"""
+        result = await self.db_pool.simple_select_one_onecol(
+            table="partial_state_events",
+            keyvalues={"event_id": event_id},
+            retcol="1",
+            allow_none=True,
+            desc="is_partial_state_event",
+        )
+        return result is not None
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 8f09dd8e87..e9a0cdc6be 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -112,7 +112,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         for tp in self.hs.config.server.mau_limits_reserved_threepids[
             : self.hs.config.server.max_mau_value
         ]:
-            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+            user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
                 tp["medium"], canonicalise_email(tp["address"])
             )
             if user_id:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0416df64ce..94068940b9 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -20,6 +20,7 @@ from typing import (
     TYPE_CHECKING,
     Any,
     Awaitable,
+    Collection,
     Dict,
     List,
     Optional,
@@ -1543,6 +1544,42 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             lock=False,
         )
 
+    async def store_partial_state_room(
+        self,
+        room_id: str,
+        servers: Collection[str],
+    ) -> None:
+        """Mark the given room as containing events with partial state
+
+        Args:
+            room_id: the ID of the room
+            servers: other servers known to be in the room
+        """
+        await self.db_pool.runInteraction(
+            "store_partial_state_room",
+            self._store_partial_state_room_txn,
+            room_id,
+            servers,
+        )
+
+    @staticmethod
+    def _store_partial_state_room_txn(
+        txn: LoggingTransaction, room_id: str, servers: Collection[str]
+    ) -> None:
+        DatabasePool.simple_insert_txn(
+            txn,
+            table="partial_state_rooms",
+            values={
+                "room_id": room_id,
+            },
+        )
+        DatabasePool.simple_insert_many_txn(
+            txn,
+            table="partial_state_rooms_servers",
+            keys=("room_id", "server_name"),
+            values=((room_id, s) for s in servers),
+        )
+
     async def maybe_store_room_on_outlier_membership(
         self, room_id: str, room_version: RoomVersion
     ) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 59e6ab983a..683a5e5d13 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -115,6 +115,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
+    EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings"
 
     def __init__(
         self,
@@ -147,6 +148,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
+        )
+
     async def _background_reindex_search(self, progress, batch_size):
         # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
@@ -372,6 +377,27 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
         return num_rows
 
+    async def _background_delete_non_strings(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Deletes rows with non-string `value`s from `event_search` if using sqlite.
+
+        Prior to Synapse 1.44.0, malformed events received over federation could cause integers
+        to be inserted into the `event_search` table when using sqlite.
+        """
+
+        def delete_non_strings_txn(txn: LoggingTransaction) -> None:
+            txn.execute("DELETE FROM event_search WHERE typeof(value) != 'text'")
+
+        await self.db_pool.runInteraction(
+            self.EVENT_SEARCH_DELETE_NON_STRINGS, delete_non_strings_txn
+        )
+
+        await self.db_pool.updates._end_background_update(
+            self.EVENT_SEARCH_DELETE_NON_STRINGS
+        )
+        return 1
+
 
 class SearchStore(SearchBackgroundUpdateStore):
     def __init__(
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 2fb3e65192..417aef1dbc 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -42,6 +42,16 @@ logger = logging.getLogger(__name__)
 MAX_STATE_DELTA_HOPS = 100
 
 
+def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
+    v = KNOWN_ROOM_VERSIONS.get(room_version_id)
+    if not v:
+        raise UnsupportedRoomVersionError(
+            "Room %s uses a room version %s which is no longer supported"
+            % (room_id, room_version_id)
+        )
+    return v
+
+
 # this inherits from EventsWorkerStore because it calls self.get_events
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers."""
@@ -62,11 +72,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 Typically this happens if support for the room's version has been
                 removed from Synapse.
         """
-        return await self.db_pool.runInteraction(
-            "get_room_version_txn",
-            self.get_room_version_txn,
-            room_id,
-        )
+        room_version_id = await self.get_room_version_id(room_id)
+        return _retrieve_and_check_room_version(room_id, room_version_id)
 
     def get_room_version_txn(
         self, txn: LoggingTransaction, room_id: str
@@ -82,15 +89,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 removed from Synapse.
         """
         room_version_id = self.get_room_version_id_txn(txn, room_id)
-        v = KNOWN_ROOM_VERSIONS.get(room_version_id)
-
-        if not v:
-            raise UnsupportedRoomVersionError(
-                "Room %s uses a room version %s which is no longer supported"
-                % (room_id, room_version_id)
-            )
-
-        return v
+        return _retrieve_and_check_room_version(room_id, room_version_id)
 
     @cached(max_entries=10000)
     async def get_room_version_id(self, room_id: str) -> str:
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 428d66a617..7d543fdbe0 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -427,21 +427,21 @@ class EventsPersistenceStorage:
             # NB: Assumes that we are only persisting events for one room
             # at a time.
 
-            # map room_id->list[event_ids] giving the new forward
+            # map room_id->set[event_ids] giving the new forward
             # extremities in each room
-            new_forward_extremeties = {}
+            new_forward_extremities: Dict[str, Set[str]] = {}
 
             # map room_id->(type,state_key)->event_id tracking the full
             # state in each room after adding these events.
             # This is simply used to prefill the get_current_state_ids
             # cache
-            current_state_for_room = {}
+            current_state_for_room: Dict[str, StateMap[str]] = {}
 
             # map room_id->(to_delete, to_insert) where to_delete is a list
             # of type/state keys to remove from current state, and to_insert
             # is a map (type,key)->event_id giving the state delta in each
             # room
-            state_delta_for_room = {}
+            state_delta_for_room: Dict[str, DeltaState] = {}
 
             # Set of remote users which were in rooms the server has left. We
             # should check if we still share any rooms and if not we mark their
@@ -460,14 +460,13 @@ class EventsPersistenceStorage:
                         )
 
                     for room_id, ev_ctx_rm in events_by_room.items():
-                        latest_event_ids = (
+                        latest_event_ids = set(
                             await self.main_store.get_latest_event_ids_in_room(room_id)
                         )
                         new_latest_event_ids = await self._calculate_new_extremities(
                             room_id, ev_ctx_rm, latest_event_ids
                         )
 
-                        latest_event_ids = set(latest_event_ids)
                         if new_latest_event_ids == latest_event_ids:
                             # No change in extremities, so no change in state
                             continue
@@ -478,7 +477,7 @@ class EventsPersistenceStorage:
                         # extremities, so we'll `continue` above and skip this bit.)
                         assert new_latest_event_ids, "No forward extremities left!"
 
-                        new_forward_extremeties[room_id] = new_latest_event_ids
+                        new_forward_extremities[room_id] = new_latest_event_ids
 
                         len_1 = (
                             len(latest_event_ids) == 1
@@ -533,7 +532,7 @@ class EventsPersistenceStorage:
                             # extremities, so we'll `continue` above and skip this bit.)
                             assert new_latest_event_ids, "No forward extremities left!"
 
-                            new_forward_extremeties[room_id] = new_latest_event_ids
+                            new_forward_extremities[room_id] = new_latest_event_ids
 
                         # If either are not None then there has been a change,
                         # and we need to work out the delta (or use that
@@ -567,7 +566,7 @@ class EventsPersistenceStorage:
                             )
                             if not is_still_joined:
                                 logger.info("Server no longer in room %s", room_id)
-                                latest_event_ids = []
+                                latest_event_ids = set()
                                 current_state = {}
                                 delta.no_longer_in_room = True
 
@@ -582,7 +581,7 @@ class EventsPersistenceStorage:
                 chunk,
                 current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
-                new_forward_extremeties=new_forward_extremeties,
+                new_forward_extremities=new_forward_extremities,
                 use_negative_stream_ordering=backfilled,
                 inhibit_local_membership_updates=backfilled,
             )
@@ -596,7 +595,7 @@ class EventsPersistenceStorage:
         room_id: str,
         event_contexts: List[Tuple[EventBase, EventContext]],
         latest_event_ids: Collection[str],
-    ):
+    ) -> Set[str]:
         """Calculates the new forward extremities for a room given events to
         persist.
 
@@ -906,9 +905,9 @@ class EventsPersistenceStorage:
             # Ideally we'd figure out a way of still being able to drop old
             # dummy events that reference local events, but this is good enough
             # as a first cut.
-            events_to_check = [event]
+            events_to_check: Collection[EventBase] = [event]
             while events_to_check:
-                new_events = set()
+                new_events: Set[str] = set()
                 for event_to_check in events_to_check:
                     if self.is_mine_id(event_to_check.sender):
                         if event_to_check.type != EventTypes.Dummy:
diff --git a/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql
new file mode 100644
index 0000000000..815c0cc390
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/04partial_state_rooms.sql
@@ -0,0 +1,41 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- rooms which we have done a partial-state-style join to
+CREATE TABLE IF NOT EXISTS partial_state_rooms (
+    room_id TEXT PRIMARY KEY,
+    FOREIGN KEY(room_id) REFERENCES rooms(room_id)
+);
+
+-- a list of remote servers we believe are in the room
+CREATE TABLE IF NOT EXISTS partial_state_rooms_servers (
+    room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
+    server_name TEXT NOT NULL,
+    UNIQUE(room_id, server_name)
+);
+
+-- a list of events with partial state. We can't store this in the `events` table
+-- itself, because `events` is meant to be append-only.
+CREATE TABLE IF NOT EXISTS partial_state_events (
+    -- the room_id is denormalised for efficient indexing (the canonical source is `events`)
+    room_id TEXT NOT NULL REFERENCES partial_state_rooms(room_id),
+    event_id TEXT NOT NULL REFERENCES events(event_id),
+    UNIQUE(event_id)
+);
+
+CREATE INDEX IF NOT EXISTS partial_state_events_room_id_idx
+     ON partial_state_events (room_id);
+
+
diff --git a/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite
new file mode 100644
index 0000000000..140df65264
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/05_delete_non_strings_from_event_search.sql.sqlite
@@ -0,0 +1,22 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Delete rows with non-string `value`s from `event_search` if using sqlite.
+--
+-- Prior to Synapse 1.44.0, malformed events received over federation could
+-- cause integers to be inserted into the `event_search` table.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (6805, 'event_search_sqlite_delete_non_strings', '{}');
diff --git a/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py
new file mode 100644
index 0000000000..a2ec4fc26e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/05partial_state_rooms_triggers.py
@@ -0,0 +1,72 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This migration adds triggers to the partial_state_events tables to enforce uniqueness
+
+Triggers cannot be expressed in .sql files, so we have to use a separate file.
+"""
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+    # complain if the room_id in partial_state_events doesn't match
+    # that in `events`. We already have a fk constraint which ensures that the event
+    # exists in `events`, so all we have to do is raise if there is a row with a
+    # matching stream_ordering but not a matching room_id.
+    if isinstance(database_engine, Sqlite3Engine):
+        cur.execute(
+            """
+            CREATE TRIGGER IF NOT EXISTS partial_state_events_bad_room_id
+            BEFORE INSERT ON partial_state_events
+            FOR EACH ROW
+            BEGIN
+                SELECT RAISE(ABORT, 'Incorrect room_id in partial_state_events')
+                WHERE EXISTS (
+                    SELECT 1 FROM events
+                    WHERE events.event_id = NEW.event_id
+                       AND events.room_id != NEW.room_id
+                );
+            END;
+            """
+        )
+    elif isinstance(database_engine, PostgresEngine):
+        cur.execute(
+            """
+            CREATE OR REPLACE FUNCTION check_partial_state_events() RETURNS trigger AS $BODY$
+            BEGIN
+                IF EXISTS (
+                    SELECT 1 FROM events
+                    WHERE events.event_id = NEW.event_id
+                       AND events.room_id != NEW.room_id
+                ) THEN
+                    RAISE EXCEPTION 'Incorrect room_id in partial_state_events';
+                END IF;
+                RETURN NEW;
+            END;
+            $BODY$ LANGUAGE plpgsql;
+            """
+        )
+
+        cur.execute(
+            """
+            CREATE TRIGGER check_partial_state_events BEFORE INSERT OR UPDATE ON partial_state_events
+            FOR EACH ROW
+            EXECUTE PROCEDURE check_partial_state_events()
+            """
+        )
+    else:
+        raise NotImplementedError("Unknown database engine")
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 4ec2a713cf..fb8fe17295 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -48,7 +48,7 @@ class EventSources:
             # all the attributes of `_EventSourcesInner` are annotated.
             *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner))  # type: ignore[misc]
         )
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def get_current_token(self) -> StreamToken:
         push_rules_key = self.store.get_max_push_rules_stream_id()
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 511f52534b..58b4220ff3 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -81,7 +81,9 @@ json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
 
 
 def unwrapFirstError(failure: Failure) -> Failure:
-    # defer.gatherResults and DeferredLists wrap failures.
+    # Deprecated: you probably just want to catch defer.FirstError and reraise
+    # the subFailure's value, which will do a better job of preserving stacktraces.
+    # (actually, you probably want to use yieldable_gather_results anyway)
     failure.trap(defer.FirstError)
     return failure.value.subFailure  # type: ignore[union-attr]  # Issue in Twisted's annotations
 
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 3f7299aff7..60c03a66fd 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -29,6 +29,7 @@ from typing import (
     Hashable,
     Iterable,
     Iterator,
+    List,
     Optional,
     Set,
     Tuple,
@@ -51,7 +52,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
-from synapse.util import Clock, unwrapFirstError
+from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
@@ -193,9 +194,9 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
 T = TypeVar("T")
 
 
-def concurrently_execute(
+async def concurrently_execute(
     func: Callable[[T], Any], args: Iterable[T], limit: int
-) -> defer.Deferred:
+) -> None:
     """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
 
@@ -221,20 +222,14 @@ def concurrently_execute(
     # We use `itertools.islice` to handle the case where the number of args is
     # less than the limit, avoiding needlessly spawning unnecessary background
     # tasks.
-    return make_deferred_yieldable(
-        defer.gatherResults(
-            [
-                run_in_background(_concurrently_execute_inner, value)
-                for value in itertools.islice(it, limit)
-            ],
-            consumeErrors=True,
-        )
-    ).addErrback(unwrapFirstError)
+    await yieldable_gather_results(
+        _concurrently_execute_inner, (value for value in itertools.islice(it, limit))
+    )
 
 
-def yieldable_gather_results(
-    func: Callable, iter: Iterable, *args: Any, **kwargs: Any
-) -> defer.Deferred:
+async def yieldable_gather_results(
+    func: Callable[..., Awaitable[T]], iter: Iterable, *args: Any, **kwargs: Any
+) -> List[T]:
     """Executes the function with each argument concurrently.
 
     Args:
@@ -245,15 +240,30 @@ def yieldable_gather_results(
         **kwargs: Keyword arguments to be passed to each call to func
 
     Returns
-        Deferred[list]: Resolved when all functions have been invoked, or errors if
-        one of the function calls fails.
+        A list containing the results of the function
     """
-    return make_deferred_yieldable(
-        defer.gatherResults(
-            [run_in_background(func, item, *args, **kwargs) for item in iter],
-            consumeErrors=True,
+    try:
+        return await make_deferred_yieldable(
+            defer.gatherResults(
+                [run_in_background(func, item, *args, **kwargs) for item in iter],
+                consumeErrors=True,
+            )
         )
-    ).addErrback(unwrapFirstError)
+    except defer.FirstError as dfe:
+        # unwrap the error from defer.gatherResults.
+
+        # The raised exception's traceback only includes func() etc if
+        # the 'await' happens before the exception is thrown - ie if the failure
+        # happens *asynchronously* - otherwise Twisted throws away the traceback as it
+        # could be large.
+        #
+        # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
+        # we could throw Twisted into the fires of Mordor.
+
+        # suppress exception chaining, because the FirstError doesn't tell us anything
+        # very interesting.
+        assert isinstance(dfe.subFailure.value, BaseException)
+        raise dfe.subFailure.value from None
 
 
 T1 = TypeVar("T1")
@@ -545,7 +555,10 @@ class ReadWriteLock:
             finally:
                 with PreserveLoggingContext():
                     new_defer.callback(None)
-                if self.key_to_current_writer[key] == new_defer:
+                # `self.key_to_current_writer[key]` may be missing if there was another
+                # writer waiting for us and it completed entirely within the
+                # `new_defer.callback()` call above.
+                if self.key_to_current_writer.get(key) == new_defer:
                     self.key_to_current_writer.pop(key)
 
         return _ctx_manager()
@@ -655,3 +668,22 @@ def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
         return value
 
     return DoneAwaitable(value)
+
+
+def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
+    """Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`.
+
+    Args:
+        deferred: The `Deferred` to protect against cancellation. Must not follow the
+            Synapse logcontext rules.
+
+    Returns:
+        A new `Deferred`, which will contain the result of the original `Deferred`,
+        but will not propagate cancellation through to the original. When cancelled,
+        the new `Deferred` will fail with a `CancelledError` and will not follow the
+        Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap
+        the new `Deferred`.
+    """
+    new_deferred: defer.Deferred[T] = defer.Deferred()
+    deferred.chainDeferred(new_deferred)
+    return new_deferred
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index df4fb156c2..1cdead02f1 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,7 @@ import inspect
 import logging
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Dict,
     Generic,
@@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given an iterable of keys it looks in the cache to find any hits, then passes
-    the tuple of missing keys to the wrapped function.
+    the set of missing keys to the wrapped function.
 
-    Once wrapped, the function returns a Deferred which resolves to the list
-    of results.
+    Once wrapped, the function returns a Deferred which resolves to a Dict mapping from
+    input key to output value.
     """
 
     def __init__(
         self,
-        orig: Callable[..., Any],
+        orig: Callable[..., Awaitable[Dict]],
         cached_method_name: str,
         list_name: str,
         num_args: Optional[int] = None,
@@ -385,13 +386,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
     def __get__(
         self, obj: Optional[Any], objtype: Optional[Type] = None
-    ) -> Callable[..., Any]:
+    ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]:
         cached_method = getattr(obj, self.cached_method_name)
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
-        def wrapped(*args: Any, **kwargs: Any) -> Any:
+        def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
             # If we're passed a cache_context then we'll want to call its
             # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -444,39 +445,38 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     deferred: "defer.Deferred[Any]" = defer.Deferred()
                     deferreds_map[arg] = deferred
                     key = arg_to_cache_key(arg)
-                    cache.set(key, deferred, callback=invalidate_callback)
+                    cached_defers.append(
+                        cache.set(key, deferred, callback=invalidate_callback)
+                    )
 
                 def complete_all(res: Dict[Hashable, Any]) -> None:
-                    # the wrapped function has completed. It returns a
-                    # a dict. We can now resolve the observable deferreds in
-                    # the cache and update our own result map.
-                    for e in missing:
+                    # the wrapped function has completed. It returns a dict.
+                    # We can now update our own result map, and then resolve the
+                    # observable deferreds in the cache.
+                    for e, d1 in deferreds_map.items():
                         val = res.get(e, None)
-                        deferreds_map[e].callback(val)
+                        # make sure we update the results map before running the
+                        # deferreds, because as soon as we run the last deferred, the
+                        # gatherResults() below will complete and return the result
+                        # dict to our caller.
                         results[e] = val
+                        d1.callback(val)
 
-                def errback(f: Failure) -> Failure:
-                    # the wrapped function has failed. Invalidate any cache
-                    # entries we're supposed to be populating, and fail
-                    # their deferreds.
-                    for e in missing:
-                        key = arg_to_cache_key(e)
-                        cache.invalidate(key)
-                        deferreds_map[e].errback(f)
-
-                    # return the failure, to propagate to our caller.
-                    return f
+                def errback_all(f: Failure) -> None:
+                    # the wrapped function has failed. Propagate the failure into
+                    # the cache, which will invalidate the entry, and cause the
+                    # relevant cached_deferreds to fail, which will propagate the
+                    # failure to our caller.
+                    for d1 in deferreds_map.values():
+                        d1.errback(f)
 
                 args_to_call = dict(arg_dict)
-                # copy the missing set before sending it to the callee, to guard against
-                # modification.
-                args_to_call[self.list_name] = tuple(missing)
-
-                cached_defers.append(
-                    defer.maybeDeferred(
-                        preserve_fn(self.orig), **args_to_call
-                    ).addCallbacks(complete_all, errback)
-                )
+                args_to_call[self.list_name] = missing
+
+                # dispatch the call, and attach the two handlers
+                defer.maybeDeferred(
+                    preserve_fn(self.orig), **args_to_call
+                ).addCallbacks(complete_all, errback_all)
 
             if cached_defers:
                 d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
new file mode 100644
index 0000000000..3a1f6b3c75
--- /dev/null
+++ b/synapse/util/check_dependencies.py
@@ -0,0 +1,127 @@
+import logging
+from typing import Iterable, NamedTuple, Optional
+
+from packaging.requirements import Requirement
+
+DISTRIBUTION_NAME = "matrix-synapse"
+
+try:
+    from importlib import metadata
+except ImportError:
+    import importlib_metadata as metadata  # type: ignore[no-redef]
+
+
+class DependencyException(Exception):
+    @property
+    def message(self) -> str:
+        return "\n".join(
+            [
+                "Missing Requirements: %s" % (", ".join(self.dependencies),),
+                "To install run:",
+                "    pip install --upgrade --force %s" % (" ".join(self.dependencies),),
+                "",
+            ]
+        )
+
+    @property
+    def dependencies(self) -> Iterable[str]:
+        for i in self.args[0]:
+            yield '"' + i + '"'
+
+
+EXTRAS = set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra"))
+
+
+class Dependency(NamedTuple):
+    requirement: Requirement
+    must_be_installed: bool
+
+
+def _generic_dependencies() -> Iterable[Dependency]:
+    """Yield pairs (requirement, must_be_installed)."""
+    requirements = metadata.requires(DISTRIBUTION_NAME)
+    assert requirements is not None
+    for raw_requirement in requirements:
+        req = Requirement(raw_requirement)
+        # https://packaging.pypa.io/en/latest/markers.html#usage notes that
+        #   > Evaluating an extra marker with no environment is an error
+        # so we pass in a dummy empty extra value here.
+        must_be_installed = req.marker is None or req.marker.evaluate({"extra": ""})
+        yield Dependency(req, must_be_installed)
+
+
+def _dependencies_for_extra(extra: str) -> Iterable[Dependency]:
+    """Yield additional dependencies needed for a given `extra`."""
+    requirements = metadata.requires(DISTRIBUTION_NAME)
+    assert requirements is not None
+    for raw_requirement in requirements:
+        req = Requirement(raw_requirement)
+        # Exclude mandatory deps by only selecting deps needed with this extra.
+        if (
+            req.marker is not None
+            and req.marker.evaluate({"extra": extra})
+            and not req.marker.evaluate({"extra": ""})
+        ):
+            yield Dependency(req, True)
+
+
+def _not_installed(requirement: Requirement, extra: Optional[str] = None) -> str:
+    if extra:
+        return f"Need {requirement.name} for {extra}, but it is not installed"
+    else:
+        return f"Need {requirement.name}, but it is not installed"
+
+
+def _incorrect_version(
+    requirement: Requirement, got: str, extra: Optional[str] = None
+) -> str:
+    if extra:
+        return f"Need {requirement} for {extra}, but got {requirement.name}=={got}"
+    else:
+        return f"Need {requirement}, but got {requirement.name}=={got}"
+
+
+def check_requirements(extra: Optional[str] = None) -> None:
+    """Check Synapse's dependencies are present and correctly versioned.
+
+    If provided, `extra` must be the name of an pacakging extra (e.g. "saml2" in
+    `pip install matrix-synapse[saml2]`).
+
+    If `extra` is None, this function checks that
+    - all mandatory dependencies are installed and correctly versioned, and
+    - each optional dependency that's installed is correctly versioned.
+
+    If `extra` is not None, this function checks that
+    - the dependencies needed for that extra are installed and correctly versioned.
+
+    :raises DependencyException: if a dependency is missing or incorrectly versioned.
+    :raises ValueError: if this extra does not exist.
+    """
+    # First work out which dependencies are required, and which are optional.
+    if extra is None:
+        dependencies = _generic_dependencies()
+    elif extra in EXTRAS:
+        dependencies = _dependencies_for_extra(extra)
+    else:
+        raise ValueError(f"Synapse does not provide the feature '{extra}'")
+
+    deps_unfulfilled = []
+    errors = []
+
+    for (requirement, must_be_installed) in dependencies:
+        try:
+            dist: metadata.Distribution = metadata.distribution(requirement.name)
+        except metadata.PackageNotFoundError:
+            if must_be_installed:
+                deps_unfulfilled.append(requirement.name)
+                errors.append(_not_installed(requirement, extra))
+        else:
+            if not requirement.specifier.contains(dist.version):
+                deps_unfulfilled.append(requirement.name)
+                errors.append(_incorrect_version(requirement, dist.version, extra))
+
+    if deps_unfulfilled:
+        for err in errors:
+            logging.error(err)
+
+        raise DependencyException(deps_unfulfilled)
diff --git a/synctl b/synctl
index 0e54f4847b..1ab36949c7 100755
--- a/synctl
+++ b/synctl
@@ -37,6 +37,13 @@ YELLOW = "\x1b[1;33m"
 RED = "\x1b[1;31m"
 NORMAL = "\x1b[m"
 
+SYNCTL_CACHE_FACTOR_WARNING = """\
+Setting 'synctl_cache_factor' in the config is deprecated. Instead, please do
+one of the following:
+ - Either set the environment variable 'SYNAPSE_CACHE_FACTOR'
+ - or set 'caches.global_factor' in the homeserver config.
+--------------------------------------------------------------------------------"""
+
 
 def pid_running(pid):
     try:
@@ -228,6 +235,7 @@ def main():
     start_stop_synapse = True
 
     if cache_factor:
+        write(SYNCTL_CACHE_FACTOR_WARNING)
         os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
 
     cache_factors = config.get("synctl_cache_factors", {})
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 4b53b6d40b..3e05789923 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -16,6 +16,8 @@ from unittest.mock import Mock
 
 import pymacaroons
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.auth import Auth
 from synapse.api.constants import UserTypes
 from synapse.api.errors import (
@@ -26,8 +28,10 @@ from synapse.api.errors import (
     ResourceLimitError,
 )
 from synapse.appservice import ApplicationService
+from synapse.server import HomeServer
 from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import Requester
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import simple_async_mock
@@ -36,10 +40,10 @@ from tests.utils import mock_getRawHeaders
 
 
 class AuthTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
         self.store = Mock()
 
-        hs.get_datastore = Mock(return_value=self.store)
+        hs.datastores.main = self.store
         hs.get_auth_handler().store = self.store
         self.auth = Auth(hs)
 
@@ -67,7 +71,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
-        self.assertEquals(requester.user.to_string(), self.test_user)
+        self.assertEqual(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self):
         self.store.get_user_by_access_token = simple_async_mock(None)
@@ -105,7 +109,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
-        self.assertEquals(requester.user.to_string(), self.test_user)
+        self.assertEqual(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_appservice_valid_token_good_ip(self):
         from netaddr import IPSet
@@ -124,7 +128,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
-        self.assertEquals(requester.user.to_string(), self.test_user)
+        self.assertEqual(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_appservice_valid_token_bad_ip(self):
         from netaddr import IPSet
@@ -191,7 +195,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
-        self.assertEquals(
+        self.assertEqual(
             requester.user.to_string(), masquerading_user_id.decode("utf8")
         )
 
@@ -238,10 +242,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = self.get_success(self.auth.get_user_by_req(request))
-        self.assertEquals(
+        self.assertEqual(
             requester.user.to_string(), masquerading_user_id.decode("utf8")
         )
-        self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8"))
+        self.assertEqual(requester.device_id, masquerading_device_id.decode("utf8"))
 
     @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
     def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
@@ -271,8 +275,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         failure = self.get_failure(self.auth.get_user_by_req(request), AuthError)
-        self.assertEquals(failure.value.code, 400)
-        self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE)
+        self.assertEqual(failure.value.code, 400)
+        self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
 
     def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
         self.store.get_user_by_access_token = simple_async_mock(
@@ -305,7 +309,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         self.get_success(self.auth.get_user_by_req(request))
-        self.assertEquals(self.store.insert_client_ip.call_count, 2)
+        self.assertEqual(self.store.insert_client_ip.call_count, 2)
 
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = simple_async_mock(
@@ -365,9 +369,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
 
         e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
-        self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact)
-        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
-        self.assertEquals(e.value.code, 403)
+        self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
+        self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEqual(e.value.code, 403)
 
         # Ensure does not throw an error
         self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
@@ -469,9 +473,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
         e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
-        self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact)
-        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
-        self.assertEquals(e.value.code, 403)
+        self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
+        self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEqual(e.value.code, 403)
 
     def test_hs_disabled_no_server_notices_user(self):
         """Check that 'hs_disabled_message' works correctly when there is no
@@ -484,9 +488,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
         e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
-        self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact)
-        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
-        self.assertEquals(e.value.code, 403)
+        self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
+        self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEqual(e.value.code, 403)
 
     def test_server_notices_mxid_special_cased(self):
         self.auth_blocking._hs_disabled = True
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index b7fc33dc94..8c3354ce3c 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -18,6 +18,7 @@
 from unittest.mock import patch
 
 import jsonschema
+from frozendict import frozendict
 
 from synapse.api.constants import EventContentFields
 from synapse.api.errors import SynapseError
@@ -40,7 +41,7 @@ def MockEvent(**kwargs):
 class FilteringTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.filtering = hs.get_filtering()
-        self.datastore = hs.get_datastore()
+        self.datastore = hs.get_datastores().main
 
     def test_errors_on_invalid_filters(self):
         invalid_filters = [
@@ -327,6 +328,15 @@ class FilteringTestCase(unittest.HomeserverTestCase):
 
         self.assertFalse(Filter(self.hs, definition)._check(event))
 
+        # check it works with frozendicts too
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!secretbase:unknown",
+            content=frozendict({EventContentFields.LABELS: ["#fun"]}),
+        )
+        self.assertTrue(Filter(self.hs, definition)._check(event))
+
     def test_filter_not_labels(self):
         definition = {"org.matrix.not_labels": ["#fun"]}
         event = MockEvent(
@@ -364,7 +374,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         )
 
         results = self.get_success(user_filter.filter_presence(events=events))
-        self.assertEquals(events, results)
+        self.assertEqual(events, results)
 
     def test_filter_presence_no_match(self):
         user_filter_json = {"presence": {"types": ["m.*"]}}
@@ -388,7 +398,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         )
 
         results = self.get_success(user_filter.filter_presence(events=events))
-        self.assertEquals([], results)
+        self.assertEqual([], results)
 
     def test_filter_room_state_match(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
@@ -407,7 +417,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         )
 
         results = self.get_success(user_filter.filter_room_state(events=events))
-        self.assertEquals(events, results)
+        self.assertEqual(events, results)
 
     def test_filter_room_state_no_match(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
@@ -428,7 +438,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
         )
 
         results = self.get_success(user_filter.filter_room_state(events))
-        self.assertEquals([], results)
+        self.assertEqual([], results)
 
     def test_filter_rooms(self):
         definition = {
@@ -444,7 +454,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
 
         filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids))
 
-        self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
+        self.assertEqual(filtered_room_ids, ["!allowed:example.com"])
 
     @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
     def test_filter_relations(self):
@@ -486,7 +496,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
                     Filter(self.hs, definition)._check_event_relations(events)
                 )
             )
-        self.assertEquals(filtered_events, events[1:])
+        self.assertEqual(filtered_events, events[1:])
 
     def test_add_filter(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
@@ -497,8 +507,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals(filter_id, 0)
-        self.assertEquals(
+        self.assertEqual(filter_id, 0)
+        self.assertEqual(
             user_filter_json,
             (
                 self.get_success(
@@ -524,6 +534,6 @@ class FilteringTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals(filter.get_filter_json(), user_filter_json)
+        self.assertEqual(filter.get_filter_json(), user_filter_json)
 
-        self.assertRegexpMatches(repr(filter), r"<FilterCollection \{.*\}>")
+        self.assertRegex(repr(filter), r"<FilterCollection \{.*\}>")
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index dcf0110c16..483d5463ad 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -8,25 +8,25 @@ from tests import unittest
 class TestRatelimiter(unittest.HomeserverTestCase):
     def test_allowed_via_can_do_action(self):
         limiter = Ratelimiter(
-            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id", _time_now_s=0)
         )
         self.assertTrue(allowed)
-        self.assertEquals(10.0, time_allowed)
+        self.assertEqual(10.0, time_allowed)
 
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id", _time_now_s=5)
         )
         self.assertFalse(allowed)
-        self.assertEquals(10.0, time_allowed)
+        self.assertEqual(10.0, time_allowed)
 
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id", _time_now_s=10)
         )
         self.assertTrue(allowed)
-        self.assertEquals(20.0, time_allowed)
+        self.assertEqual(20.0, time_allowed)
 
     def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
         appservice = ApplicationService(
@@ -39,25 +39,25 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
         limiter = Ratelimiter(
-            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
-        self.assertEquals(10.0, time_allowed)
+        self.assertEqual(10.0, time_allowed)
 
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertFalse(allowed)
-        self.assertEquals(10.0, time_allowed)
+        self.assertEqual(10.0, time_allowed)
 
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
-        self.assertEquals(20.0, time_allowed)
+        self.assertEqual(20.0, time_allowed)
 
     def test_allowed_appservice_via_can_requester_do_action(self):
         appservice = ApplicationService(
@@ -70,29 +70,29 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
         limiter = Ratelimiter(
-            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
         )
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
-        self.assertEquals(-1, time_allowed)
+        self.assertEqual(-1, time_allowed)
 
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertTrue(allowed)
-        self.assertEquals(-1, time_allowed)
+        self.assertEqual(-1, time_allowed)
 
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
-        self.assertEquals(-1, time_allowed)
+        self.assertEqual(-1, time_allowed)
 
     def test_allowed_via_ratelimit(self):
         limiter = Ratelimiter(
-            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
         )
 
         # Shouldn't raise
@@ -116,7 +116,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
         limiter = Ratelimiter(
-            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
         )
 
         # First attempt should be allowed
@@ -162,7 +162,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
         limiter = Ratelimiter(
-            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
         )
 
         # First attempt should be allowed
@@ -190,7 +190,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
 
     def test_pruning(self):
         limiter = Ratelimiter(
-            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=1
         )
         self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
@@ -208,7 +208,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         """Test that users that have ratelimiting disabled in the DB aren't
         ratelimited.
         """
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         user_id = "@user:test"
         requester = create_requester(user_id)
@@ -233,7 +233,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
 
     def test_multiple_actions(self):
         limiter = Ratelimiter(
-            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=3
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
         )
         # Test that 4 actions aren't allowed with a maximum burst of 3.
         allowed, time_allowed = self.get_success_or_raise(
@@ -246,7 +246,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
             limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0)
         )
         self.assertTrue(allowed)
-        self.assertEquals(10.0, time_allowed)
+        self.assertEqual(10.0, time_allowed)
 
         # Test that, after doing these 3 actions, we can't do any more action without
         # waiting.
@@ -254,7 +254,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
             limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0)
         )
         self.assertFalse(allowed)
-        self.assertEquals(10.0, time_allowed)
+        self.assertEqual(10.0, time_allowed)
 
         # Test that after waiting we can do only 1 action.
         allowed, time_allowed = self.get_success_or_raise(
@@ -269,7 +269,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         self.assertTrue(allowed)
         # The time allowed is the current time because we could still repeat the action
         # once.
-        self.assertEquals(10.0, time_allowed)
+        self.assertEqual(10.0, time_allowed)
 
         allowed, time_allowed = self.get_success_or_raise(
             limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10)
@@ -277,7 +277,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         self.assertFalse(allowed)
         # The time allowed doesn't change despite allowed being False because, while we
         # don't allow 2 actions, we could still do 1.
-        self.assertEquals(10.0, time_allowed)
+        self.assertEqual(10.0, time_allowed)
 
         # Test that after waiting a bit more we can do 2 actions.
         allowed, time_allowed = self.get_success_or_raise(
@@ -286,4 +286,4 @@ class TestRatelimiter(unittest.HomeserverTestCase):
         self.assertTrue(allowed)
         # The time allowed is the current time because we could still repeat the action
         # once.
-        self.assertEquals(20.0, time_allowed)
+        self.assertEqual(20.0, time_allowed)
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index 19eb4c79d0..df731eb599 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -32,7 +32,7 @@ class PhoneHomeTestCase(HomeserverTestCase):
         self.helper.send(room_id, "message", tok=access_token)
 
         # Check the R30 results do not count that user.
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 0})
 
         # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
@@ -40,7 +40,7 @@ class PhoneHomeTestCase(HomeserverTestCase):
         self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
 
         # (Make sure the user isn't somehow counted by this point.)
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 0})
 
         # Send a message (this counts as activity)
@@ -51,21 +51,21 @@ class PhoneHomeTestCase(HomeserverTestCase):
         self.reactor.advance(2 * 60 * 60)
 
         # *Now* the user is counted.
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 1, "unknown": 1})
 
         # Advance 29 days. The user has now not posted for 29 days.
         self.reactor.advance(29 * ONE_DAY_IN_SECONDS)
 
         # The user is still counted.
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 1, "unknown": 1})
 
         # Advance another day. The user has now not posted for 30 days.
         self.reactor.advance(ONE_DAY_IN_SECONDS)
 
         # The user is now no longer counted in R30.
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 0})
 
     def test_r30_minimum_usage_using_default_config(self):
@@ -84,7 +84,7 @@ class PhoneHomeTestCase(HomeserverTestCase):
         self.helper.send(room_id, "message", tok=access_token)
 
         # Check the R30 results do not count that user.
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 0})
 
         # Advance 30 days (+ 1 second, because strict inequality causes issues if we are
@@ -92,7 +92,7 @@ class PhoneHomeTestCase(HomeserverTestCase):
         self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1)
 
         # (Make sure the user isn't somehow counted by this point.)
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 0})
 
         # Send a message (this counts as activity)
@@ -103,14 +103,14 @@ class PhoneHomeTestCase(HomeserverTestCase):
         self.reactor.advance(2 * 60 * 60)
 
         # *Now* the user is counted.
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 1, "unknown": 1})
 
         # Advance 27 days. The user has now not posted for 27 days.
         self.reactor.advance(27 * ONE_DAY_IN_SECONDS)
 
         # The user is still counted.
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 1, "unknown": 1})
 
         # Advance another day. The user has now not posted for 28 days.
@@ -119,7 +119,7 @@ class PhoneHomeTestCase(HomeserverTestCase):
         # The user is now no longer counted in R30.
         # (This is because the user_ips table has been pruned, which by default
         # only preserves the last 28 days of entries.)
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 0})
 
     def test_r30_user_must_be_retained_for_at_least_a_month(self):
@@ -135,7 +135,7 @@ class PhoneHomeTestCase(HomeserverTestCase):
         self.helper.send(room_id, "message", tok=access_token)
 
         # Check the user does not contribute to R30 yet.
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 0})
 
         for _ in range(30):
@@ -144,14 +144,16 @@ class PhoneHomeTestCase(HomeserverTestCase):
             self.helper.send(room_id, "I'm still here", tok=access_token)
 
             # Notice that the user *still* does not contribute to R30!
-            r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+            r30_results = self.get_success(
+                self.hs.get_datastores().main.count_r30_users()
+            )
             self.assertEqual(r30_results, {"all": 0})
 
         self.reactor.advance(ONE_DAY_IN_SECONDS)
         self.helper.send(room_id, "Still here!", tok=access_token)
 
         # *Now* the user appears in R30.
-        r30_results = self.get_success(self.hs.get_datastore().count_r30_users())
+        r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users())
         self.assertEqual(r30_results, {"all": 1, "unknown": 1})
 
 
@@ -196,7 +198,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
         # (user_daily_visits is updated every 5 minutes using a looping call.)
         self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # Check the R30 results do not count that user.
         r30_results = self.get_success(store.count_r30v2_users())
@@ -275,7 +277,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
         # (user_daily_visits is updated every 5 minutes using a looping call.)
         self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # Check the user does not contribute to R30 yet.
         r30_results = self.get_success(store.count_r30v2_users())
@@ -347,7 +349,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
         # (user_daily_visits is updated every 5 minutes using a looping call.)
         self.reactor.advance(FIVE_MINUTES_IN_SECONDS)
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # Check that the user does not contribute to R30v2, even though it's been
         # more than 30 days since registration.
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 8fb6687f89..1cbb059357 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -68,8 +68,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
             events=events,
             ephemeral=[],
             to_device_messages=[],  # txn made and saved
+            one_time_key_counts={},
+            unused_fallback_keys={},
         )
-        self.assertEquals(0, len(self.txnctrl.recoverers))  # no recoverer made
+        self.assertEqual(0, len(self.txnctrl.recoverers))  # no recoverer made
         txn.complete.assert_called_once_with(self.store)  # txn completed
 
     def test_single_service_down(self):
@@ -92,9 +94,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
             events=events,
             ephemeral=[],
             to_device_messages=[],  # txn made and saved
+            one_time_key_counts={},
+            unused_fallback_keys={},
         )
-        self.assertEquals(0, txn.send.call_count)  # txn not sent though
-        self.assertEquals(0, txn.complete.call_count)  # or completed
+        self.assertEqual(0, txn.send.call_count)  # txn not sent though
+        self.assertEqual(0, txn.complete.call_count)  # or completed
 
     def test_single_service_up_txn_not_sent(self):
         # Test: The AS is up and the txn is not sent. A Recoverer is made and
@@ -114,12 +118,17 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
 
         self.store.create_appservice_txn.assert_called_once_with(
-            service=service, events=events, ephemeral=[], to_device_messages=[]
+            service=service,
+            events=events,
+            ephemeral=[],
+            to_device_messages=[],
+            one_time_key_counts={},
+            unused_fallback_keys={},
         )
-        self.assertEquals(1, self.recoverer_fn.call_count)  # recoverer made
-        self.assertEquals(1, self.recoverer.recover.call_count)  # and invoked
-        self.assertEquals(1, len(self.txnctrl.recoverers))  # and stored
-        self.assertEquals(0, txn.complete.call_count)  # txn not completed
+        self.assertEqual(1, self.recoverer_fn.call_count)  # recoverer made
+        self.assertEqual(1, self.recoverer.recover.call_count)  # and invoked
+        self.assertEqual(1, len(self.txnctrl.recoverers))  # and stored
+        self.assertEqual(0, txn.complete.call_count)  # txn not completed
         self.store.set_appservice_state.assert_called_once_with(
             service, ApplicationServiceState.DOWN  # service marked as down
         )
@@ -152,17 +161,17 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
 
         self.recoverer.recover()
         # shouldn't have called anything prior to waiting for exp backoff
-        self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
+        self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
         txn.send = simple_async_mock(True)
         txn.complete = simple_async_mock(None)
         # wait for exp backoff
         self.clock.advance_time(2)
-        self.assertEquals(1, txn.send.call_count)
-        self.assertEquals(1, txn.complete.call_count)
+        self.assertEqual(1, txn.send.call_count)
+        self.assertEqual(1, txn.complete.call_count)
         # 2 because it needs to get None to know there are no more txns
-        self.assertEquals(2, self.store.get_oldest_unsent_txn.call_count)
+        self.assertEqual(2, self.store.get_oldest_unsent_txn.call_count)
         self.callback.assert_called_once_with(self.recoverer)
-        self.assertEquals(self.recoverer.service, self.service)
+        self.assertEqual(self.recoverer.service, self.service)
 
     def test_recover_retry_txn(self):
         txn = Mock()
@@ -178,26 +187,26 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
 
         self.recoverer.recover()
-        self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
+        self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
         txn.send = simple_async_mock(False)
         txn.complete = simple_async_mock(None)
         self.clock.advance_time(2)
-        self.assertEquals(1, txn.send.call_count)
-        self.assertEquals(0, txn.complete.call_count)
-        self.assertEquals(0, self.callback.call_count)
+        self.assertEqual(1, txn.send.call_count)
+        self.assertEqual(0, txn.complete.call_count)
+        self.assertEqual(0, self.callback.call_count)
         self.clock.advance_time(4)
-        self.assertEquals(2, txn.send.call_count)
-        self.assertEquals(0, txn.complete.call_count)
-        self.assertEquals(0, self.callback.call_count)
+        self.assertEqual(2, txn.send.call_count)
+        self.assertEqual(0, txn.complete.call_count)
+        self.assertEqual(0, self.callback.call_count)
         self.clock.advance_time(8)
-        self.assertEquals(3, txn.send.call_count)
-        self.assertEquals(0, txn.complete.call_count)
-        self.assertEquals(0, self.callback.call_count)
+        self.assertEqual(3, txn.send.call_count)
+        self.assertEqual(0, txn.complete.call_count)
+        self.assertEqual(0, self.callback.call_count)
         txn.send = simple_async_mock(True)  # successfully send the txn
         pop_txn = True  # returns the txn the first time, then no more.
         self.clock.advance_time(16)
-        self.assertEquals(1, txn.send.call_count)  # new mock reset call count
-        self.assertEquals(1, txn.complete.call_count)
+        self.assertEqual(1, txn.send.call_count)  # new mock reset call count
+        self.assertEqual(1, txn.complete.call_count)
         self.callback.assert_called_once_with(self.recoverer)
 
 
@@ -216,7 +225,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         service = Mock(id=4)
         event = Mock()
         self.scheduler.enqueue_for_appservice(service, events=[event])
-        self.txn_ctrl.send.assert_called_once_with(service, [event], [], [])
+        self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None)
 
     def test_send_single_event_with_queue(self):
         d = defer.Deferred()
@@ -231,12 +240,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         # (call enqueue_for_appservice multiple times deliberately)
         self.scheduler.enqueue_for_appservice(service, events=[event2])
         self.scheduler.enqueue_for_appservice(service, events=[event3])
-        self.txn_ctrl.send.assert_called_with(service, [event], [], [])
-        self.assertEquals(1, self.txn_ctrl.send.call_count)
+        self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None)
+        self.assertEqual(1, self.txn_ctrl.send.call_count)
         # Resolve the send event: expect the queued events to be sent
         d.callback(service)
-        self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], [])
-        self.assertEquals(2, self.txn_ctrl.send.call_count)
+        self.txn_ctrl.send.assert_called_with(
+            service, [event2, event3], [], [], None, None
+        )
+        self.assertEqual(2, self.txn_ctrl.send.call_count)
 
     def test_multiple_service_queues(self):
         # Tests that each service has its own queue, and that they don't block
@@ -261,16 +272,16 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         # send events for different ASes and make sure they are sent
         self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
         self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
-        self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [])
+        self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None)
         self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
         self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
-        self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [])
+        self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None)
 
         # make sure callbacks for a service only send queued events for THAT
         # service
         srv_2_defer.callback(srv2)
-        self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [])
-        self.assertEquals(3, self.txn_ctrl.send.call_count)
+        self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None)
+        self.assertEqual(3, self.txn_ctrl.send.call_count)
 
     def test_send_large_txns(self):
         srv_1_defer = defer.Deferred()
@@ -288,28 +299,38 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
             self.scheduler.enqueue_for_appservice(service, [event], [])
 
         # Expect the first event to be sent immediately.
-        self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], [])
+        self.txn_ctrl.send.assert_called_with(
+            service, [event_list[0]], [], [], None, None
+        )
         srv_1_defer.callback(service)
         # Then send the next 100 events
-        self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], [])
+        self.txn_ctrl.send.assert_called_with(
+            service, event_list[1:101], [], [], None, None
+        )
         srv_2_defer.callback(service)
         # Then the final 99 events
-        self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], [])
-        self.assertEquals(3, self.txn_ctrl.send.call_count)
+        self.txn_ctrl.send.assert_called_with(
+            service, event_list[101:], [], [], None, None
+        )
+        self.assertEqual(3, self.txn_ctrl.send.call_count)
 
     def test_send_single_ephemeral_no_queue(self):
         # Expect the event to be sent immediately.
         service = Mock(id=4, name="service")
         event_list = [Mock(name="event")]
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
-        self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
+        self.txn_ctrl.send.assert_called_once_with(
+            service, [], event_list, [], None, None
+        )
 
     def test_send_multiple_ephemeral_no_queue(self):
         # Expect the event to be sent immediately.
         service = Mock(id=4, name="service")
         event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
-        self.txn_ctrl.send.assert_called_once_with(service, [], event_list, [])
+        self.txn_ctrl.send.assert_called_once_with(
+            service, [], event_list, [], None, None
+        )
 
     def test_send_single_ephemeral_with_queue(self):
         d = defer.Deferred()
@@ -324,15 +345,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         # Send more events: expect send() to NOT be called multiple times.
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
-        self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [])
-        self.assertEquals(1, self.txn_ctrl.send.call_count)
+        self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None)
+        self.assertEqual(1, self.txn_ctrl.send.call_count)
         # Resolve txn_ctrl.send
         d.callback(service)
         # Expect the queued events to be sent
         self.txn_ctrl.send.assert_called_with(
-            service, [], event_list_2 + event_list_3, []
+            service, [], event_list_2 + event_list_3, [], None, None
         )
-        self.assertEquals(2, self.txn_ctrl.send.call_count)
+        self.assertEqual(2, self.txn_ctrl.send.call_count)
 
     def test_send_large_txns_ephemeral(self):
         d = defer.Deferred()
@@ -343,7 +364,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
         second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
         event_list = first_chunk + second_chunk
         self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
-        self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, [])
+        self.txn_ctrl.send.assert_called_once_with(
+            service, [], first_chunk, [], None, None
+        )
         d.callback(service)
-        self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [])
-        self.assertEquals(2, self.txn_ctrl.send.call_count)
+        self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None)
+        self.assertEqual(2, self.txn_ctrl.send.call_count)
diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py
index a72a0103d3..694020fbef 100644
--- a/tests/crypto/test_event_signing.py
+++ b/tests/crypto/test_event_signing.py
@@ -63,14 +63,14 @@ class EventSigningTestCase(unittest.TestCase):
 
         self.assertTrue(hasattr(event, "hashes"))
         self.assertIn("sha256", event.hashes)
-        self.assertEquals(
+        self.assertEqual(
             event.hashes["sha256"], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
         )
 
         self.assertTrue(hasattr(event, "signatures"))
         self.assertIn(HOSTNAME, event.signatures)
         self.assertIn(KEY_NAME, event.signatures["domain"])
-        self.assertEquals(
+        self.assertEqual(
             event.signatures[HOSTNAME][KEY_NAME],
             "2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+"
             "aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA",
@@ -97,14 +97,14 @@ class EventSigningTestCase(unittest.TestCase):
 
         self.assertTrue(hasattr(event, "hashes"))
         self.assertIn("sha256", event.hashes)
-        self.assertEquals(
+        self.assertEqual(
             event.hashes["sha256"], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
         )
 
         self.assertTrue(hasattr(event, "signatures"))
         self.assertIn(HOSTNAME, event.signatures)
         self.assertIn(KEY_NAME, event.signatures["domain"])
-        self.assertEquals(
+        self.assertEqual(
             event.signatures[HOSTNAME][KEY_NAME],
             "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw"
             "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA",
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 17a9fb63a1..d00ef24ca8 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -76,7 +76,7 @@ class FakeRequest:
 @logcontext_clean
 class KeyringTestCase(unittest.HomeserverTestCase):
     def check_context(self, val, expected):
-        self.assertEquals(getattr(current_context(), "request", None), expected)
+        self.assertEqual(getattr(current_context(), "request", None), expected)
         return val
 
     def test_verify_json_objects_for_server_awaits_previous_requests(self):
@@ -96,7 +96,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         async def first_lookup_fetch(
             server_name: str, key_ids: List[str], minimum_valid_until_ts: int
         ) -> Dict[str, FetchKeyResult]:
-            # self.assertEquals(current_context().request.id, "context_11")
+            # self.assertEqual(current_context().request.id, "context_11")
             self.assertEqual(server_name, "server10")
             self.assertEqual(key_ids, [get_key_id(key1)])
             self.assertEqual(minimum_valid_until_ts, 0)
@@ -137,7 +137,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         async def second_lookup_fetch(
             server_name: str, key_ids: List[str], minimum_valid_until_ts: int
         ) -> Dict[str, FetchKeyResult]:
-            # self.assertEquals(current_context().request.id, "context_12")
+            # self.assertEqual(current_context().request.id, "context_12")
             return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
 
         mock_fetcher.get_keys.reset_mock()
@@ -179,7 +179,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         kr = keyring.Keyring(self.hs)
 
         key1 = signedjson.key.generate_signing_key(1)
-        r = self.hs.get_datastore().store_server_verify_keys(
+        r = self.hs.get_datastores().main.store_server_verify_keys(
             "server9",
             time.time() * 1000,
             [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
@@ -272,7 +272,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         )
 
         key1 = signedjson.key.generate_signing_key(1)
-        r = self.hs.get_datastore().store_server_verify_keys(
+        r = self.hs.get_datastores().main.store_server_verify_keys(
             "server9",
             time.time() * 1000,
             [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
@@ -448,7 +448,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
-            self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+            self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
         )
         res = key_json[lookup_triplet]
         self.assertEqual(len(res), 1)
@@ -564,7 +564,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
-            self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+            self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
         )
         res = key_json[lookup_triplet]
         self.assertEqual(len(res), 1)
@@ -683,7 +683,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
-            self.hs.get_datastore().get_server_keys_json([lookup_triplet])
+            self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
         )
         res = key_json[lookup_triplet]
         self.assertEqual(len(res), 1)
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index ca27388ae8..defbc68c18 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -28,7 +28,7 @@ class TestEventContext(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
 
         self.user_id = self.register_user("u1", "pass")
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 1dea09e480..45e3395b33 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -395,7 +395,7 @@ class SerializeEventTestCase(unittest.TestCase):
         return serialize_event(ev, 1479807801915, only_event_fields=fields)
 
     def test_event_fields_works_with_keys(self):
-        self.assertEquals(
+        self.assertEqual(
             self.serialize(
                 MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"]
             ),
@@ -403,7 +403,7 @@ class SerializeEventTestCase(unittest.TestCase):
         )
 
     def test_event_fields_works_with_nested_keys(self):
-        self.assertEquals(
+        self.assertEqual(
             self.serialize(
                 MockEvent(
                     sender="@alice:localhost",
@@ -416,7 +416,7 @@ class SerializeEventTestCase(unittest.TestCase):
         )
 
     def test_event_fields_works_with_dot_keys(self):
-        self.assertEquals(
+        self.assertEqual(
             self.serialize(
                 MockEvent(
                     sender="@alice:localhost",
@@ -429,7 +429,7 @@ class SerializeEventTestCase(unittest.TestCase):
         )
 
     def test_event_fields_works_with_nested_dot_keys(self):
-        self.assertEquals(
+        self.assertEqual(
             self.serialize(
                 MockEvent(
                     sender="@alice:localhost",
@@ -445,7 +445,7 @@ class SerializeEventTestCase(unittest.TestCase):
         )
 
     def test_event_fields_nops_with_unknown_keys(self):
-        self.assertEquals(
+        self.assertEqual(
             self.serialize(
                 MockEvent(
                     sender="@alice:localhost",
@@ -458,7 +458,7 @@ class SerializeEventTestCase(unittest.TestCase):
         )
 
     def test_event_fields_nops_with_non_dict_keys(self):
-        self.assertEquals(
+        self.assertEqual(
             self.serialize(
                 MockEvent(
                     sender="@alice:localhost",
@@ -471,7 +471,7 @@ class SerializeEventTestCase(unittest.TestCase):
         )
 
     def test_event_fields_nops_with_array_keys(self):
-        self.assertEquals(
+        self.assertEqual(
             self.serialize(
                 MockEvent(
                     sender="@alice:localhost",
@@ -484,7 +484,7 @@ class SerializeEventTestCase(unittest.TestCase):
         )
 
     def test_event_fields_all_fields_if_empty(self):
-        self.assertEquals(
+        self.assertEqual(
             self.serialize(
                 MockEvent(
                     type="foo",
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index e40ef95874..9f1115dd23 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -50,19 +50,19 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         channel = self.make_signed_federation_request(
             "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         complexity = channel.json_body["v1"]
         self.assertTrue(complexity > 0, complexity)
 
         # Artificially raise the complexity
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
 
         # Get the room complexity again -- make sure it's our artificial value
         channel = self.make_signed_federation_request(
             "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         complexity = channel.json_body["v1"]
         self.assertEqual(complexity, 1.23)
 
@@ -149,7 +149,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         )
 
         # Artificially raise the complexity
-        self.hs.get_datastore().get_current_state_event_counts = (
+        self.hs.get_datastores().main.get_current_state_event_counts = (
             lambda x: make_awaitable(600)
         )
 
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index f0aa8ed9db..2873b4d430 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -64,7 +64,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
             Dictionary of { event_id: str, stream_ordering: int }
         """
         event_id, stream_ordering = self.get_success(
-            self.hs.get_datastore().db_pool.execute(
+            self.hs.get_datastores().main.db_pool.execute(
                 "test:get_destination_rooms",
                 None,
                 """
@@ -125,7 +125,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         self.pump()
 
         lsso_1 = self.get_success(
-            self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+            self.hs.get_datastores().main.get_destination_last_successful_stream_ordering(
                 "host2"
             )
         )
@@ -141,7 +141,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         event_id_2 = self.helper.send(room, "rabbits!", tok=u1_token)["event_id"]
 
         lsso_2 = self.get_success(
-            self.hs.get_datastore().get_destination_last_successful_stream_ordering(
+            self.hs.get_datastores().main.get_destination_last_successful_stream_ordering(
                 "host2"
             )
         )
@@ -216,7 +216,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
 
         # let's also clear any backoffs
         self.get_success(
-            self.hs.get_datastore().set_destination_retry_timings("host2", None, 0, 0)
+            self.hs.get_datastores().main.set_destination_retry_timings(
+                "host2", None, 0, 0
+            )
         )
 
         # bring the remote online and clear the received pdu list
@@ -296,13 +298,13 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
 
         # destination_rooms should already be populated, but let us pretend that we already
         # sent (successfully) up to and including event id 2
-        event_2 = self.get_success(self.hs.get_datastore().get_event(event_id_2))
+        event_2 = self.get_success(self.hs.get_datastores().main.get_event(event_id_2))
 
         # also fetch event 5 so we know its last_successful_stream_ordering later
-        event_5 = self.get_success(self.hs.get_datastore().get_event(event_id_5))
+        event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5))
 
         self.get_success(
-            self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+            self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
                 "host2", event_2.internal_metadata.stream_ordering
             )
         )
@@ -359,7 +361,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # ASSERT:
         # - All servers are up to date so none should have outstanding catch-up
         outstanding_when_successful = self.get_success(
-            self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+            self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None)
         )
         self.assertEqual(outstanding_when_successful, [])
 
@@ -370,7 +372,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # - Mark zzzerver as being backed-off from
         now = self.clock.time_msec()
         self.get_success(
-            self.hs.get_datastore().set_destination_retry_timings(
+            self.hs.get_datastores().main.set_destination_retry_timings(
                 "zzzerver", now, now, 24 * 60 * 60 * 1000  # retry in 1 day
             )
         )
@@ -382,14 +384,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # - all remotes are outstanding
         # - they are returned in batches of 25, in order
         outstanding_1 = self.get_success(
-            self.hs.get_datastore().get_catch_up_outstanding_destinations(None)
+            self.hs.get_datastores().main.get_catch_up_outstanding_destinations(None)
         )
 
         self.assertEqual(len(outstanding_1), 25)
         self.assertEqual(outstanding_1, server_names[0:25])
 
         outstanding_2 = self.get_success(
-            self.hs.get_datastore().get_catch_up_outstanding_destinations(
+            self.hs.get_datastores().main.get_catch_up_outstanding_destinations(
                 outstanding_1[-1]
             )
         )
@@ -457,7 +459,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         )
 
         self.get_success(
-            self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+            self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
                 "host2", event_1.internal_metadata.stream_ordering
             )
         )
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index b2376e2db9..60e0c31f43 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -176,7 +176,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         def get_users_who_share_room_with_user(user_id):
             return defer.succeed({"@user2:host2"})
 
-        hs.get_datastore().get_users_who_share_room_with_user = (
+        hs.get_datastores().main.get_users_who_share_room_with_user = (
             get_users_who_share_room_with_user
         )
 
@@ -395,7 +395,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         # run the prune job
         self.reactor.advance(10)
         self.get_success(
-            self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+            self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1)
         )
 
         # recover the server
@@ -445,7 +445,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         # run the prune job
         self.reactor.advance(10)
         self.get_success(
-            self.hs.get_datastore()._prune_old_outbound_device_pokes(prune_age=1)
+            self.hs.get_datastores().main._prune_old_outbound_device_pokes(prune_age=1)
         )
 
         # recover the server
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index d084919ef7..30e7e5093a 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -59,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
             "/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
             query_content,
         )
-        self.assertEquals(400, channel.code, channel.result)
+        self.assertEqual(400, channel.code, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
 
 
@@ -125,7 +125,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
         channel = self.make_signed_federation_request(
             "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         self.assertEqual(
             channel.json_body["room_version"],
@@ -157,7 +157,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
         channel = self.make_signed_federation_request(
             "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
         )
-        self.assertEquals(403, channel.code, channel.result)
+        self.assertEqual(403, channel.code, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
 
 
@@ -189,7 +189,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
             f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
             f"?ver={DEFAULT_ROOM_VERSION}",
         )
-        self.assertEquals(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.code, 200, channel.json_body)
         return channel.json_body
 
     def test_send_join(self):
@@ -209,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
             f"/_matrix/federation/v2/send_join/{self._room_id}/x",
             content=join_event_dict,
         )
-        self.assertEquals(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.code, 200, channel.json_body)
 
         # we should get complete room state back
         returned_state = [
@@ -266,7 +266,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
             f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
             content=join_event_dict,
         )
-        self.assertEquals(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.code, 200, channel.json_body)
 
         # expect a reduced room state
         returned_state = [
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 686f42ab48..648a01618e 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -169,7 +169,7 @@ class KnockingStrippedStateEventHelperMixin(TestCase):
             self.assertIn(event_type, expected_room_state)
 
             # Check the state content matches
-            self.assertEquals(
+            self.assertEqual(
                 expected_room_state[event_type]["content"], event["content"]
             )
 
@@ -198,7 +198,7 @@ class FederationKnockingTestCase(
     ]
 
     def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastore()
+        self.store = homeserver.get_datastores().main
 
         # We're not going to be properly signing events as our remote homeserver is fake,
         # therefore disable event signature checks.
@@ -256,7 +256,7 @@ class FederationKnockingTestCase(
                 RoomVersions.V7.identifier,
             ),
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         # Note: We don't expect the knock membership event to be sent over federation as
         # part of the stripped room state, as the knocking homeserver already has that
@@ -266,11 +266,11 @@ class FederationKnockingTestCase(
         knock_event = channel.json_body["event"]
 
         # Check that the event has things we expect in it
-        self.assertEquals(knock_event["room_id"], room_id)
-        self.assertEquals(knock_event["sender"], fake_knocking_user_id)
-        self.assertEquals(knock_event["state_key"], fake_knocking_user_id)
-        self.assertEquals(knock_event["type"], EventTypes.Member)
-        self.assertEquals(knock_event["content"]["membership"], Membership.KNOCK)
+        self.assertEqual(knock_event["room_id"], room_id)
+        self.assertEqual(knock_event["sender"], fake_knocking_user_id)
+        self.assertEqual(knock_event["state_key"], fake_knocking_user_id)
+        self.assertEqual(knock_event["type"], EventTypes.Member)
+        self.assertEqual(knock_event["content"]["membership"], Membership.KNOCK)
 
         # Turn the event json dict into a proper event.
         # We won't sign it properly, but that's OK as we stub out event auth in `prepare`
@@ -294,7 +294,7 @@ class FederationKnockingTestCase(
             % (room_id, signed_knock_event.event_id),
             signed_knock_event_json,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         # Check that we got the stripped room state in return
         room_state_events = channel.json_body["knock_state_events"]
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index eb62addda8..5f001c33b0 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from tests import unittest
-from tests.unittest import override_config
+from tests.unittest import DEBUG, override_config
 
 
 class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
@@ -26,7 +26,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
             "GET",
             "/_matrix/federation/v1/publicRooms",
         )
-        self.assertEquals(403, channel.code)
+        self.assertEqual(403, channel.code)
 
     @override_config({"allow_public_rooms_over_federation": True})
     def test_open_public_room_list_over_federation(self):
@@ -37,4 +37,22 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
             "GET",
             "/_matrix/federation/v1/publicRooms",
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
+
+    @DEBUG
+    def test_edu_debugging_doesnt_explode(self):
+        """Sanity check incoming federation succeeds with `synapse.debug_8631` enabled.
+
+        Remove this when we strip out issue_8631_logger.
+        """
+        channel = self.make_signed_federation_request(
+            "PUT",
+            "/_matrix/federation/v1/send/txn_id_1234/",
+            content={
+                "edus": [
+                    {"edu_type": "m.device_list_update", "content": {"foo": "bar"}}
+                ],
+                "pdus": [],
+            },
+        )
+        self.assertEqual(200, channel.code)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index fe57ff2671..072e6bbcdd 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 import synapse.storage
-from synapse.appservice import ApplicationService
+from synapse.appservice import (
+    ApplicationService,
+    TransactionOneTimeKeyCounts,
+    TransactionUnusedFallbackKeys,
+)
 from synapse.handlers.appservice import ApplicationServicesHandler
-from synapse.rest.client import login, receipts, room, sendtodevice
+from synapse.rest.client import login, receipts, register, room, sendtodevice
+from synapse.server import HomeServer
 from synapse.types import RoomStreamToken
+from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
 from tests.test_utils import make_awaitable, simple_async_mock
+from tests.unittest import override_config
 from tests.utils import MockClock
 
 
@@ -38,7 +46,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_as_api = Mock()
         self.mock_scheduler = Mock()
         hs = Mock()
-        hs.get_datastore.return_value = self.mock_store
+        hs.get_datastores.return_value = Mock(main=self.mock_store)
         self.mock_store.get_received_ts.return_value = make_awaitable(0)
         self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
         self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
@@ -139,8 +147,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_as_api.query_alias.assert_called_once_with(
             interested_service, room_alias_str
         )
-        self.assertEquals(result.room_id, room_id)
-        self.assertEquals(result.servers, servers)
+        self.assertEqual(result.room_id, room_id)
+        self.assertEqual(result.servers, servers)
 
     def test_get_3pe_protocols_no_appservices(self):
         self.mock_store.get_app_services.return_value = []
@@ -148,7 +156,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
         )
         self.mock_as_api.get_3pe_protocol.assert_not_called()
-        self.assertEquals(response, {})
+        self.assertEqual(response, {})
 
     def test_get_3pe_protocols_no_protocols(self):
         service = self._mkservice(False, [])
@@ -157,7 +165,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
         self.mock_as_api.get_3pe_protocol.assert_not_called()
-        self.assertEquals(response, {})
+        self.assertEqual(response, {})
 
     def test_get_3pe_protocols_protocol_no_response(self):
         service = self._mkservice(False, ["my-protocol"])
@@ -169,7 +177,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_as_api.get_3pe_protocol.assert_called_once_with(
             service, "my-protocol"
         )
-        self.assertEquals(response, {})
+        self.assertEqual(response, {})
 
     def test_get_3pe_protocols_select_one_protocol(self):
         service = self._mkservice(False, ["my-protocol"])
@@ -183,7 +191,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_as_api.get_3pe_protocol.assert_called_once_with(
             service, "my-protocol"
         )
-        self.assertEquals(
+        self.assertEqual(
             response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
         )
 
@@ -199,7 +207,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_as_api.get_3pe_protocol.assert_called_once_with(
             service, "my-protocol"
         )
-        self.assertEquals(
+        self.assertEqual(
             response, {"my-protocol": {"x-protocol-data": 42, "instances": []}}
         )
 
@@ -214,7 +222,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
         self.mock_as_api.get_3pe_protocol.assert_called()
-        self.assertEquals(
+        self.assertEqual(
             response,
             {
                 "my-protocol": {"x-protocol-data": 42, "instances": []},
@@ -246,7 +254,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
         # It's expected that the second service's data doesn't appear in the response
-        self.assertEquals(
+        self.assertEqual(
             response,
             {
                 "my-protocol": {
@@ -355,7 +363,9 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
 
         # Mock out application services, and allow defining our own in tests
         self._services: List[ApplicationService] = []
-        self.hs.get_datastore().get_app_services = Mock(return_value=self._services)
+        self.hs.get_datastores().main.get_app_services = Mock(
+            return_value=self._services
+        )
 
         # A user on the homeserver.
         self.local_user_device_id = "local_device"
@@ -426,7 +436,14 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         #
         # The uninterested application service should not have been notified at all.
         self.send_mock.assert_called_once()
-        service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0]
+        (
+            service,
+            _events,
+            _ephemeral,
+            to_device_messages,
+            _otks,
+            _fbks,
+        ) = self.send_mock.call_args[0]
 
         # Assert that this was the same to-device message that local_user sent
         self.assertEqual(service, interested_appservice)
@@ -494,7 +511,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         # Create a fake device per message. We can't send to-device messages to
         # a device that doesn't exist.
         self.get_success(
-            self.hs.get_datastore().db_pool.simple_insert_many(
+            self.hs.get_datastores().main.db_pool.simple_insert_many(
                 desc="test_application_services_receive_burst_of_to_device",
                 table="devices",
                 keys=("user_id", "device_id"),
@@ -510,7 +527,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
 
         # Seed the device_inbox table with our fake messages
         self.get_success(
-            self.hs.get_datastore().add_messages_to_device_inbox(messages, {})
+            self.hs.get_datastores().main.add_messages_to_device_inbox(messages, {})
         )
 
         # Now have local_user send a final to-device message to exclusive_as_user. All unsent
@@ -538,7 +555,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         service_id_to_message_count: Dict[str, int] = {}
 
         for call in self.send_mock.call_args_list:
-            service, _events, _ephemeral, to_device_messages = call[0]
+            service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0]
 
             # Check that this was made to an interested service
             self.assertIn(service, interested_appservices)
@@ -580,3 +597,174 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         self._services.append(appservice)
 
         return appservice
+
+
+class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
+    # Argument indices for pulling out arguments from a `send_mock`.
+    ARG_OTK_COUNTS = 4
+    ARG_FALLBACK_KEYS = 5
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        register.register_servlets,
+        room.register_servlets,
+        sendtodevice.register_servlets,
+        receipts.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
+        # we can track what's going out
+        self.send_mock = simple_async_mock()
+        hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[assignment]  # We assign to a method.
+
+        # Define an application service for the tests
+        self._service_token = "VERYSECRET"
+        self._service = ApplicationService(
+            self._service_token,
+            "as1.invalid",
+            "as1",
+            "@as.sender:test",
+            namespaces={
+                "users": [
+                    {"regex": "@_as_.*:test", "exclusive": True},
+                    {"regex": "@as.sender:test", "exclusive": True},
+                ]
+            },
+            msc3202_transaction_extensions=True,
+        )
+        self.hs.get_datastores().main.services_cache = [self._service]
+
+        # Register some appservice users
+        self._sender_user, self._sender_device = self.register_appservice_user(
+            "as.sender", self._service_token
+        )
+        self._namespaced_user, self._namespaced_device = self.register_appservice_user(
+            "_as_user1", self._service_token
+        )
+
+        # Register a real user as well.
+        self._real_user = self.register_user("real.user", "meow")
+        self._real_user_token = self.login("real.user", "meow")
+
+    async def _add_otks_for_device(
+        self, user_id: str, device_id: str, otk_count: int
+    ) -> None:
+        """
+        Add some dummy keys. It doesn't matter if they're not a real algorithm;
+        that should be opaque to the server anyway.
+        """
+        await self.hs.get_datastores().main.add_e2e_one_time_keys(
+            user_id,
+            device_id,
+            self.clock.time_msec(),
+            [("algo", f"k{i}", "{}") for i in range(otk_count)],
+        )
+
+    async def _add_fallback_key_for_device(
+        self, user_id: str, device_id: str, used: bool
+    ) -> None:
+        """
+        Adds a fake fallback key to a device, optionally marking it as used
+        right away.
+        """
+        store = self.hs.get_datastores().main
+        await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"})
+        if used is True:
+            # Mark the key as used
+            await store.db_pool.simple_update_one(
+                table="e2e_fallback_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": "algo",
+                    "key_id": "fk",
+                },
+                updatevalues={"used": True},
+                desc="_get_fallback_key_set_used",
+            )
+
+    def _set_up_devices_and_a_room(self) -> str:
+        """
+        Helper to set up devices for all the users
+        and a room for the users to talk in.
+        """
+
+        async def preparation():
+            await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
+            await self._add_fallback_key_for_device(
+                self._sender_user, self._sender_device, used=True
+            )
+            await self._add_otks_for_device(
+                self._namespaced_user, self._namespaced_device, 36
+            )
+            await self._add_fallback_key_for_device(
+                self._namespaced_user, self._namespaced_device, used=False
+            )
+
+            # Register a device for the real user, too, so that we can later ensure
+            # that we don't leak information to the AS about the non-AS user.
+            await self.hs.get_datastores().main.store_device(
+                self._real_user, "REALDEV", "UltraMatrix 3000"
+            )
+            await self._add_otks_for_device(self._real_user, "REALDEV", 50)
+
+        self.get_success(preparation())
+
+        room_id = self.helper.create_room_as(
+            self._real_user, is_public=True, tok=self._real_user_token
+        )
+        self.helper.join(
+            room_id,
+            self._namespaced_user,
+            tok=self._service_token,
+            appservice_user_id=self._namespaced_user,
+        )
+
+        # Check it was called for sanity. (This was to send the join event to the AS.)
+        self.send_mock.assert_called()
+        self.send_mock.reset_mock()
+
+        return room_id
+
+    @override_config(
+        {"experimental_features": {"msc3202_transaction_extensions": True}}
+    )
+    def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus(
+        self,
+    ) -> None:
+        """
+        Tests that:
+        - the AS receives one-time key counts and unused fallback keys for:
+            - the specified sender; and
+            - any user who is in receipt of the PDUs
+        """
+
+        room_id = self._set_up_devices_and_a_room()
+
+        # Send a message into the AS's room
+        self.helper.send(room_id, "woof woof", tok=self._real_user_token)
+
+        # Capture what was sent as an AS transaction.
+        self.send_mock.assert_called()
+        last_args, _last_kwargs = self.send_mock.call_args
+        otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS]
+        unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[
+            self.ARG_FALLBACK_KEYS
+        ]
+
+        self.assertEqual(
+            otks,
+            {
+                "@as.sender:test": {self._sender_device: {"algo": 42}},
+                "@_as_user1:test": {self._namespaced_device: {"algo": 36}},
+            },
+        )
+        self.assertEqual(
+            unused_fallbacks,
+            {
+                "@as.sender:test": {self._sender_device: []},
+                "@_as_user1:test": {self._namespaced_device: ["algo"]},
+            },
+        )
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 03b8b8615c..0c6e55e725 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -129,7 +129,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_mau_limits_exceeded_large(self):
         self.auth_blocking._limit_usage_by_mau = True
-        self.hs.get_datastore().get_monthly_active_count = Mock(
+        self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
 
@@ -140,7 +140,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
 
-        self.hs.get_datastore().get_monthly_active_count = Mock(
+        self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
         self.get_failure(
@@ -156,7 +156,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = True
 
         # Set the server to be at the edge of too many users.
-        self.hs.get_datastore().get_monthly_active_count = Mock(
+        self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.auth_blocking._max_mau_value)
         )
 
@@ -175,7 +175,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
         # If in monthly active cohort
-        self.hs.get_datastore().user_last_seen_monthly_active = Mock(
+        self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
             return_value=make_awaitable(self.clock.time_msec())
         )
         self.get_success(
@@ -192,7 +192,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_mau_limits_not_exceeded(self):
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.hs.get_datastore().get_monthly_active_count = Mock(
+        self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.small_number_of_users)
         )
         # Ensure does not raise exception
@@ -202,7 +202,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.hs.get_datastore().get_monthly_active_count = Mock(
+        self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.small_number_of_users)
         )
         self.get_success(
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 8705ff8943..a267228846 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -77,7 +77,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
     def test_map_cas_user_to_existing_user(self):
         """Existing users can log in with CAS account."""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 01096a1581..ddda36c5a9 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -34,7 +34,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
         self.user = self.register_user("user", "pass")
         self.token = self.login("user", "pass")
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 43031e07ea..683677fd07 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -28,7 +28,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         return hs
 
     def prepare(self, reactor, clock, hs):
@@ -263,7 +263,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.handler = hs.get_device_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         return hs
 
     def test_dehydrate_and_rehydrate_device(self):
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 0ea4e753e2..6e403a87c5 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -46,7 +46,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
         self.handler = hs.get_directory_handler()
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.my_room = RoomAlias.from_string("#my-room:test")
         self.your_room = RoomAlias.from_string("#your-room:test")
@@ -63,7 +63,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
         result = self.get_success(self.handler.get_association(self.my_room))
 
-        self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
+        self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
     def test_get_remote_association(self):
         self.mock_federation.make_query.return_value = make_awaitable(
@@ -72,7 +72,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
         result = self.get_success(self.handler.get_association(self.remote_room))
 
-        self.assertEquals(
+        self.assertEqual(
             {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result
         )
         self.mock_federation.make_query.assert_called_with(
@@ -94,7 +94,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
             self.handler.on_directory_query({"room_alias": "#your-room:test"})
         )
 
-        self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
+        self.assertEqual({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
 
 
 class TestCreateAlias(unittest.HomeserverTestCase):
@@ -174,7 +174,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
 
@@ -224,7 +224,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
                 create_requester(self.test_user), self.room_alias
             )
         )
-        self.assertEquals(self.room_id, result)
+        self.assertEqual(self.room_id, result)
 
         # Confirm the alias is gone.
         self.get_failure(
@@ -243,7 +243,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
                 create_requester(self.admin_user), self.room_alias
             )
         )
-        self.assertEquals(self.room_id, result)
+        self.assertEqual(self.room_id, result)
 
         # Confirm the alias is gone.
         self.get_failure(
@@ -269,7 +269,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
                 create_requester(self.test_user), self.room_alias
             )
         )
-        self.assertEquals(self.room_id, result)
+        self.assertEqual(self.room_id, result)
 
         # Confirm the alias is gone.
         self.get_failure(
@@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
 
@@ -411,7 +411,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
             b"directory/room/%23test%3Atest",
             {"room_id": room_id},
         )
-        self.assertEquals(403, channel.code, channel.result)
+        self.assertEqual(403, channel.code, channel.result)
 
     def test_allowed(self):
         room_id = self.helper.create_room_as(self.user_id)
@@ -421,7 +421,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
             b"directory/room/%23unofficial_test%3Atest",
             {"room_id": room_id},
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
     def test_denied_during_creation(self):
         """A room alias that is not allowed should be rejected during creation."""
@@ -443,8 +443,8 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
             "GET",
             b"directory/room/%23unofficial_test%3Atest",
         )
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(channel.json_body["room_id"], room_id)
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(channel.json_body["room_id"], room_id)
 
 
 class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
@@ -572,7 +572,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         channel = self.make_request(
             "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         self.room_list_handler = hs.get_room_list_handler()
         self.directory_handler = hs.get_directory_handler()
@@ -585,7 +585,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
 
         # Room list is enabled so we should get some results
         channel = self.make_request("GET", b"publicRooms")
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["chunk"]) > 0)
 
         self.room_list_handler.enable_room_list_search = False
@@ -593,7 +593,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
 
         # Room list disabled so we should get no results
         channel = self.make_request("GET", b"publicRooms")
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["chunk"]) == 0)
 
         # Room list disabled so we shouldn't be allowed to publish rooms
@@ -601,4 +601,4 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         channel = self.make_request(
             "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
-        self.assertEquals(403, channel.code, channel.result)
+        self.assertEqual(403, channel.code, channel.result)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 734ed84d78..9338ab92e9 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -34,7 +34,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.handler = hs.get_e2e_keys_handler()
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
     def test_query_local_devices_no_devices(self):
         """If the user has no devices, we expect an empty list."""
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 496b581726..e8b4e39d1a 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -45,7 +45,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state_store = hs.get_storage().state
         self._event_auth_handler = hs.get_event_auth_handler()
         return hs
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 5816295d8b..f4f7ab4845 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -44,7 +44,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
 
         self.info = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(
+            self.hs.get_datastores().main.get_user_by_access_token(
                 self.access_token,
             )
         )
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a552d8182e..e8418b6638 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -856,7 +856,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         auth_handler.complete_sso_login.reset_mock()
 
         # Test if the mxid is already taken
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         user3 = UserID.from_string("@test_user_3:test")
         self.get_success(
             store.register_user(user_id=user3.to_string(), password_hash=None)
@@ -872,7 +872,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
     def test_map_userinfo_to_existing_user(self):
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         user = UserID.from_string("@test_user:test")
         self.get_success(
             store.register_user(user_id=user.to_string(), password_hash=None)
@@ -996,7 +996,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 671dc7d083..6ddec9ecf1 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -43,7 +43,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
     servlets = [admin.register_servlets]
 
     def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastore()
+        self.store = homeserver.get_datastores().main
 
     def test_offline_to_online(self):
         wheel_timer = Mock()
@@ -61,11 +61,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
 
         self.assertTrue(persist_and_notify)
         self.assertTrue(state.currently_active)
-        self.assertEquals(new_state.state, state.state)
-        self.assertEquals(new_state.status_msg, state.status_msg)
-        self.assertEquals(state.last_federation_update_ts, now)
+        self.assertEqual(new_state.state, state.state)
+        self.assertEqual(new_state.status_msg, state.status_msg)
+        self.assertEqual(state.last_federation_update_ts, now)
 
-        self.assertEquals(wheel_timer.insert.call_count, 3)
+        self.assertEqual(wheel_timer.insert.call_count, 3)
         wheel_timer.insert.assert_has_calls(
             [
                 call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
@@ -104,11 +104,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
         self.assertFalse(persist_and_notify)
         self.assertTrue(federation_ping)
         self.assertTrue(state.currently_active)
-        self.assertEquals(new_state.state, state.state)
-        self.assertEquals(new_state.status_msg, state.status_msg)
-        self.assertEquals(state.last_federation_update_ts, now)
+        self.assertEqual(new_state.state, state.state)
+        self.assertEqual(new_state.status_msg, state.status_msg)
+        self.assertEqual(state.last_federation_update_ts, now)
 
-        self.assertEquals(wheel_timer.insert.call_count, 3)
+        self.assertEqual(wheel_timer.insert.call_count, 3)
         wheel_timer.insert.assert_has_calls(
             [
                 call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
@@ -149,11 +149,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
         self.assertFalse(persist_and_notify)
         self.assertTrue(federation_ping)
         self.assertTrue(state.currently_active)
-        self.assertEquals(new_state.state, state.state)
-        self.assertEquals(new_state.status_msg, state.status_msg)
-        self.assertEquals(state.last_federation_update_ts, now)
+        self.assertEqual(new_state.state, state.state)
+        self.assertEqual(new_state.status_msg, state.status_msg)
+        self.assertEqual(state.last_federation_update_ts, now)
 
-        self.assertEquals(wheel_timer.insert.call_count, 3)
+        self.assertEqual(wheel_timer.insert.call_count, 3)
         wheel_timer.insert.assert_has_calls(
             [
                 call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
@@ -191,11 +191,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
 
         self.assertTrue(persist_and_notify)
         self.assertFalse(state.currently_active)
-        self.assertEquals(new_state.state, state.state)
-        self.assertEquals(new_state.status_msg, state.status_msg)
-        self.assertEquals(state.last_federation_update_ts, now)
+        self.assertEqual(new_state.state, state.state)
+        self.assertEqual(new_state.status_msg, state.status_msg)
+        self.assertEqual(state.last_federation_update_ts, now)
 
-        self.assertEquals(wheel_timer.insert.call_count, 2)
+        self.assertEqual(wheel_timer.insert.call_count, 2)
         wheel_timer.insert.assert_has_calls(
             [
                 call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
@@ -227,10 +227,10 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
         self.assertFalse(persist_and_notify)
         self.assertFalse(federation_ping)
         self.assertFalse(state.currently_active)
-        self.assertEquals(new_state.state, state.state)
-        self.assertEquals(new_state.status_msg, state.status_msg)
+        self.assertEqual(new_state.state, state.state)
+        self.assertEqual(new_state.status_msg, state.status_msg)
 
-        self.assertEquals(wheel_timer.insert.call_count, 1)
+        self.assertEqual(wheel_timer.insert.call_count, 1)
         wheel_timer.insert.assert_has_calls(
             [
                 call(
@@ -259,10 +259,10 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertTrue(persist_and_notify)
-        self.assertEquals(new_state.state, state.state)
-        self.assertEquals(state.last_federation_update_ts, now)
+        self.assertEqual(new_state.state, state.state)
+        self.assertEqual(state.last_federation_update_ts, now)
 
-        self.assertEquals(wheel_timer.insert.call_count, 0)
+        self.assertEqual(wheel_timer.insert.call_count, 0)
 
     def test_online_to_idle(self):
         wheel_timer = Mock()
@@ -281,12 +281,12 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertTrue(persist_and_notify)
-        self.assertEquals(new_state.state, state.state)
-        self.assertEquals(state.last_federation_update_ts, now)
-        self.assertEquals(new_state.state, state.state)
-        self.assertEquals(new_state.status_msg, state.status_msg)
+        self.assertEqual(new_state.state, state.state)
+        self.assertEqual(state.last_federation_update_ts, now)
+        self.assertEqual(new_state.state, state.state)
+        self.assertEqual(new_state.status_msg, state.status_msg)
 
-        self.assertEquals(wheel_timer.insert.call_count, 1)
+        self.assertEqual(wheel_timer.insert.call_count, 1)
         wheel_timer.insert.assert_has_calls(
             [
                 call(
@@ -357,8 +357,8 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
-        self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
-        self.assertEquals(new_state.status_msg, status_msg)
+        self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
+        self.assertEqual(new_state.status_msg, status_msg)
 
     def test_busy_no_idle(self):
         """
@@ -380,8 +380,8 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
-        self.assertEquals(new_state.state, PresenceState.BUSY)
-        self.assertEquals(new_state.status_msg, status_msg)
+        self.assertEqual(new_state.state, PresenceState.BUSY)
+        self.assertEqual(new_state.status_msg, status_msg)
 
     def test_sync_timeout(self):
         user_id = "@foo:bar"
@@ -399,8 +399,8 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
-        self.assertEquals(new_state.state, PresenceState.OFFLINE)
-        self.assertEquals(new_state.status_msg, status_msg)
+        self.assertEqual(new_state.state, PresenceState.OFFLINE)
+        self.assertEqual(new_state.status_msg, status_msg)
 
     def test_sync_online(self):
         user_id = "@foo:bar"
@@ -420,8 +420,8 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )
 
         self.assertIsNotNone(new_state)
-        self.assertEquals(new_state.state, PresenceState.ONLINE)
-        self.assertEquals(new_state.status_msg, status_msg)
+        self.assertEqual(new_state.state, PresenceState.ONLINE)
+        self.assertEqual(new_state.status_msg, status_msg)
 
     def test_federation_ping(self):
         user_id = "@foo:bar"
@@ -440,7 +440,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
-        self.assertEquals(state, new_state)
+        self.assertEqual(state, new_state)
 
     def test_no_timeout(self):
         user_id = "@foo:bar"
@@ -477,8 +477,8 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )
 
         self.assertIsNotNone(new_state)
-        self.assertEquals(new_state.state, PresenceState.OFFLINE)
-        self.assertEquals(new_state.status_msg, status_msg)
+        self.assertEqual(new_state.state, PresenceState.OFFLINE)
+        self.assertEqual(new_state.status_msg, status_msg)
 
     def test_last_active(self):
         user_id = "@foo:bar"
@@ -497,7 +497,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
-        self.assertEquals(state, new_state)
+        self.assertEqual(state, new_state)
 
 
 class PresenceHandlerTestCase(unittest.HomeserverTestCase):
@@ -891,7 +891,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         # self.event_builder_for_2 = EventBuilderFactory(hs)
         # self.event_builder_for_2.hostname = "test2"
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 60235e5699..972cbac6e4 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -48,7 +48,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor, clock, hs: HomeServer):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.frank = UserID.from_string("@1234abcd:test")
         self.bob = UserID.from_string("@4567:test")
@@ -65,7 +65,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         displayname = self.get_success(self.handler.get_displayname(self.frank))
 
-        self.assertEquals("Frank", displayname)
+        self.assertEqual("Frank", displayname)
 
     def test_set_my_name(self):
         self.get_success(
@@ -74,7 +74,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals(
+        self.assertEqual(
             (
                 self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
@@ -90,7 +90,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals(
+        self.assertEqual(
             (
                 self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
@@ -118,7 +118,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
 
-        self.assertEquals(
+        self.assertEqual(
             (
                 self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
@@ -150,7 +150,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         displayname = self.get_success(self.handler.get_displayname(self.alice))
 
-        self.assertEquals(displayname, "Alice")
+        self.assertEqual(displayname, "Alice")
         self.mock_federation.make_query.assert_called_with(
             destination="remote",
             query_type="profile",
@@ -172,7 +172,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals({"displayname": "Caroline"}, response)
+        self.assertEqual({"displayname": "Caroline"}, response)
 
     def test_get_my_avatar(self):
         self.get_success(
@@ -182,7 +182,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
 
-        self.assertEquals("http://my.server/me.png", avatar_url)
+        self.assertEqual("http://my.server/me.png", avatar_url)
 
     def test_set_my_avatar(self):
         self.get_success(
@@ -193,7 +193,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals(
+        self.assertEqual(
             (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/pic.gif",
         )
@@ -207,7 +207,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals(
+        self.assertEqual(
             (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
         )
@@ -235,7 +235,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals(
+        self.assertEqual(
             (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
         )
@@ -325,7 +325,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
                 properties are "mimetype" (for the file's type) and "size" (for the
                 file's size).
         """
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         for name, props in names_and_props.items():
             self.get_success(
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 5de89c873b..5081b97573 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -314,4 +314,4 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
     ):
         """Tests that the _filter_out_hidden returns the expected output"""
         filtered_events = self.event_source.filter_out_hidden(events, "@me:server.org")
-        self.assertEquals(filtered_events, expected_output)
+        self.assertEqual(filtered_events, expected_output)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index cd6f2c77ae..45fd30cf43 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -154,7 +154,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.handler = self.hs.get_registration_handler()
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.lots_of_users = 100
         self.small_number_of_users = 1
 
@@ -167,12 +167,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         result_user_id, result_token = self.get_success(
             self.get_or_create_user(requester, frank.localpart, "Frankie")
         )
-        self.assertEquals(result_user_id, user_id)
+        self.assertEqual(result_user_id, user_id)
         self.assertIsInstance(result_token, str)
         self.assertGreater(len(result_token), 20)
 
     def test_if_user_exists(self):
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         frank = UserID.from_string("@frank:test")
         self.get_success(
             store.register_user(user_id=frank.to_string(), password_hash=None)
@@ -183,7 +183,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         result_user_id, result_token = self.get_success(
             self.get_or_create_user(requester, local_part, None)
         )
-        self.assertEquals(result_user_id, user_id)
+        self.assertEqual(result_user_id, user_id)
         self.assertTrue(result_token is not None)
 
     @override_config({"limit_usage_by_mau": False})
@@ -760,7 +760,7 @@ class RemoteAutoJoinTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.handler = self.hs.get_registration_handler()
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
     @override_config({"auto_join_rooms": ["#room:remotetest"]})
     def test_auto_create_auto_join_remote_room(self):
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 51b22d2998..cff07a8973 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -157,35 +157,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             state_key=room_id,
         )
 
-    def _assert_rooms(
-        self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]]
-    ) -> None:
-        """
-        Assert that the expected room IDs and events are in the response.
-
-        Args:
-            result: The result from the API call.
-            rooms_and_children: An iterable of tuples where each tuple is:
-                The expected room ID.
-                The expected IDs of any children rooms.
-        """
-        room_ids = []
-        children_ids = []
-        for room_id, children in rooms_and_children:
-            room_ids.append(room_id)
-            if children:
-                children_ids.extend([(room_id, child_id) for child_id in children])
-        self.assertCountEqual(
-            [room.get("room_id") for room in result["rooms"]], room_ids
-        )
-        self.assertCountEqual(
-            [
-                (event.get("room_id"), event.get("state_key"))
-                for event in result["events"]
-            ],
-            children_ids,
-        )
-
     def _assert_hierarchy(
         self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]]
     ) -> None:
@@ -251,11 +222,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
 
     def test_simple_space(self):
         """Test a simple space with a single room."""
-        result = self.get_success(self.handler.get_space_summary(self.user, self.space))
         # The result should have the space and the room in it, along with a link
         # from space -> room.
         expected = [(self.space, [self.room]), (self.room, ())]
-        self._assert_rooms(result, expected)
 
         result = self.get_success(
             self.handler.get_room_hierarchy(create_requester(self.user), self.space)
@@ -271,12 +240,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             self._add_child(self.space, room, self.token)
             rooms.append(room)
 
-        result = self.get_success(self.handler.get_space_summary(self.user, self.space))
-        # The spaces result should have the space and the first 50 rooms in it,
-        # along with the links from space -> room for those 50 rooms.
-        expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]]
-        self._assert_rooms(result, expected)
-
         # The result should have the space and the rooms in it, along with the links
         # from space -> room.
         expected = [(self.space, rooms)] + [(room, []) for room in rooms]
@@ -300,10 +263,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         token2 = self.login("user2", "pass")
 
         # The user can see the space since it is publicly joinable.
-        result = self.get_success(self.handler.get_space_summary(user2, self.space))
         expected = [(self.space, [self.room]), (self.room, ())]
-        self._assert_rooms(result, expected)
-
         result = self.get_success(
             self.handler.get_room_hierarchy(create_requester(user2), self.space)
         )
@@ -316,7 +276,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             body={"join_rule": JoinRules.INVITE},
             tok=self.token,
         )
-        self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
         self.get_failure(
             self.handler.get_room_hierarchy(create_requester(user2), self.space),
             AuthError,
@@ -329,9 +288,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             body={"history_visibility": HistoryVisibility.WORLD_READABLE},
             tok=self.token,
         )
-        result = self.get_success(self.handler.get_space_summary(user2, self.space))
-        self._assert_rooms(result, expected)
-
         result = self.get_success(
             self.handler.get_room_hierarchy(create_requester(user2), self.space)
         )
@@ -344,7 +300,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             body={"history_visibility": HistoryVisibility.JOINED},
             tok=self.token,
         )
-        self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
         self.get_failure(
             self.handler.get_room_hierarchy(create_requester(user2), self.space),
             AuthError,
@@ -353,9 +308,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         # Join the space and results should be returned.
         self.helper.invite(self.space, targ=user2, tok=self.token)
         self.helper.join(self.space, user2, tok=token2)
-        result = self.get_success(self.handler.get_space_summary(user2, self.space))
-        self._assert_rooms(result, expected)
-
         result = self.get_success(
             self.handler.get_room_hierarchy(create_requester(user2), self.space)
         )
@@ -363,10 +315,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
 
         # Attempting to view an unknown room returns the same error.
         self.get_failure(
-            self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname),
-            AuthError,
-        )
-        self.get_failure(
             self.handler.get_room_hierarchy(
                 create_requester(user2), "#not-a-space:" + self.hs.hostname
             ),
@@ -496,7 +444,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
 
         # Join the space.
         self.helper.join(self.space, user2, tok=token2)
-        result = self.get_success(self.handler.get_space_summary(user2, self.space))
         expected = [
             (
                 self.space,
@@ -520,7 +467,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             (world_readable_room, ()),
             (joined_room, ()),
         ]
-        self._assert_rooms(result, expected)
 
         result = self.get_success(
             self.handler.get_room_hierarchy(create_requester(user2), self.space)
@@ -554,8 +500,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         self._add_child(subspace, self.room, token=self.token)
         self._add_child(subspace, room2, self.token)
 
-        result = self.get_success(self.handler.get_space_summary(self.user, self.space))
-
         # The result should include each room a single time and each link.
         expected = [
             (self.space, [self.room, room2, subspace]),
@@ -563,7 +507,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             (subspace, [subroom, self.room, room2]),
             (subroom, ()),
         ]
-        self._assert_rooms(result, expected)
 
         result = self.get_success(
             self.handler.get_room_hierarchy(create_requester(self.user), self.space)
@@ -715,7 +658,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
 
     def test_unknown_room_version(self):
         """
-        If an room with an unknown room version is encountered it should not cause
+        If a room with an unknown room version is encountered it should not cause
         the entire summary to skip.
         """
         # Poke the database and update the room version to an unknown one.
@@ -727,11 +670,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
                 desc="updated-room-version",
             )
         )
+        # Invalidate method so that it returns the currently updated version
+        # instead of the cached version.
+        self.hs.get_datastores().main.get_room_version_id.invalidate((self.room,))
 
-        result = self.get_success(self.handler.get_space_summary(self.user, self.space))
         # The result should have only the space, along with a link from space -> room.
         expected = [(self.space, [self.room])]
-        self._assert_rooms(result, expected)
 
         result = self.get_success(
             self.handler.get_room_hierarchy(create_requester(self.user), self.space)
@@ -775,41 +719,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             "world_readable": True,
         }
 
-        async def summarize_remote_room(
-            _self, room, suggested_only, max_children, exclude_rooms
-        ):
-            return [
-                requested_room_entry,
-                _RoomEntry(
-                    subroom,
-                    {
-                        "room_id": subroom,
-                        "world_readable": True,
-                    },
-                ),
-            ]
-
         async def summarize_remote_room_hierarchy(_self, room, suggested_only):
             return requested_room_entry, {subroom: child_room}, set()
 
         # Add a room to the space which is on another server.
         self._add_child(self.space, subspace, self.token)
 
-        with mock.patch(
-            "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room",
-            new=summarize_remote_room,
-        ):
-            result = self.get_success(
-                self.handler.get_space_summary(self.user, self.space)
-            )
-
         expected = [
             (self.space, [self.room, subspace]),
             (self.room, ()),
             (subspace, [subroom]),
             (subroom, ()),
         ]
-        self._assert_rooms(result, expected)
 
         with mock.patch(
             "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy",
@@ -881,7 +802,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
                     "room_id": restricted_room,
                     "world_readable": False,
                     "join_rules": JoinRules.RESTRICTED,
-                    "allowed_spaces": [],
+                    "allowed_room_ids": [],
                 },
             ),
             (
@@ -890,7 +811,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
                     "room_id": restricted_accessible_room,
                     "world_readable": False,
                     "join_rules": JoinRules.RESTRICTED,
-                    "allowed_spaces": [self.room],
+                    "allowed_room_ids": [self.room],
                 },
             ),
             (
@@ -929,30 +850,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             ],
         )
 
-        async def summarize_remote_room(
-            _self, room, suggested_only, max_children, exclude_rooms
-        ):
-            return [subspace_room_entry] + [
-                # A copy is made of the room data since the allowed_spaces key
-                # is removed.
-                _RoomEntry(child_room[0], dict(child_room[1]))
-                for child_room in children_rooms
-            ]
-
         async def summarize_remote_room_hierarchy(_self, room, suggested_only):
             return subspace_room_entry, dict(children_rooms), set()
 
         # Add a room to the space which is on another server.
         self._add_child(self.space, subspace, self.token)
 
-        with mock.patch(
-            "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room",
-            new=summarize_remote_room,
-        ):
-            result = self.get_success(
-                self.handler.get_space_summary(self.user, self.space)
-            )
-
         expected = [
             (self.space, [self.room, subspace]),
             (self.room, ()),
@@ -976,7 +879,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             (world_readable_room, ()),
             (joined_room, ()),
         ]
-        self._assert_rooms(result, expected)
 
         with mock.patch(
             "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy",
@@ -1010,31 +912,17 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        async def summarize_remote_room(
-            _self, room, suggested_only, max_children, exclude_rooms
-        ):
-            return [fed_room_entry]
-
         async def summarize_remote_room_hierarchy(_self, room, suggested_only):
             return fed_room_entry, {}, set()
 
         # Add a room to the space which is on another server.
         self._add_child(self.space, fed_room, self.token)
 
-        with mock.patch(
-            "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room",
-            new=summarize_remote_room,
-        ):
-            result = self.get_success(
-                self.handler.get_space_summary(self.user, self.space)
-            )
-
         expected = [
             (self.space, [self.room, fed_room]),
             (self.room, ()),
             (fed_room, ()),
         ]
-        self._assert_rooms(result, expected)
 
         with mock.patch(
             "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy",
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 50551aa6e3..23941abed8 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -142,7 +142,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
     def test_map_saml_response_to_existing_user(self):
         """Existing users can log in with SAML account."""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
@@ -217,7 +217,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
         sso_handler.render_error = Mock(return_value=None)
 
         # register a user to occupy the first-choice MXID
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 56207f4db6..ecd78fa369 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -33,7 +33,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.handler = self.hs.get_stats_handler()
 
     def _add_background_updates(self):
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 07a760e91a..3aedc0767b 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -41,7 +41,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs: HomeServer):
         self.sync_handler = self.hs.get_sync_handler()
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
@@ -69,7 +69,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             self.sync_handler.wait_for_sync_for_user(requester, sync_config),
             ResourceLimitError,
         )
-        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
         self.auth_blocking._hs_disabled = False
 
@@ -80,7 +80,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             self.sync_handler.wait_for_sync_for_user(requester, sync_config),
             ResourceLimitError,
         )
-        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
     def test_unknown_room_version(self):
         """
@@ -122,7 +122,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
             b"{}",
             tok,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         # The rooms should appear in the sync response.
         result = self.get_success(
@@ -248,7 +248,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # the prev_events used when creating the join event, such that the ban does not
         # precede the join.
         mocked_get_prev_events = patch.object(
-            self.hs.get_datastore(),
+            self.hs.get_datastores().main,
             "get_prev_events_for_room",
             new_callable=MagicMock,
             return_value=make_awaitable([last_room_creation_event_id]),
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 000f9b9fde..f91a80b9fa 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -91,7 +91,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.event_source = hs.get_event_sources().sources.typing
 
-        self.datastore = hs.get_datastore()
+        self.datastore = hs.get_datastores().main
         self.datastore.get_destination_retry_timings = Mock(
             return_value=defer.succeed(None)
         )
@@ -156,7 +156,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
     def test_started_typing_local(self):
         self.room_members = [U_APPLE, U_BANANA]
 
-        self.assertEquals(self.event_source.get_current_key(), 0)
+        self.assertEqual(self.event_source.get_current_key(), 0)
 
         self.get_success(
             self.handler.started_typing(
@@ -169,13 +169,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
-        self.assertEquals(self.event_source.get_current_key(), 1)
+        self.assertEqual(self.event_source.get_current_key(), 1)
         events = self.get_success(
             self.event_source.get_new_events(
                 user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
             )
         )
-        self.assertEquals(
+        self.assertEqual(
             events[0],
             [
                 {
@@ -220,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
     def test_started_typing_remote_recv(self):
         self.room_members = [U_APPLE, U_ONION]
 
-        self.assertEquals(self.event_source.get_current_key(), 0)
+        self.assertEqual(self.event_source.get_current_key(), 0)
 
         channel = self.make_request(
             "PUT",
@@ -239,13 +239,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
-        self.assertEquals(self.event_source.get_current_key(), 1)
+        self.assertEqual(self.event_source.get_current_key(), 1)
         events = self.get_success(
             self.event_source.get_new_events(
                 user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
             )
         )
-        self.assertEquals(
+        self.assertEqual(
             events[0],
             [
                 {
@@ -259,7 +259,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
     def test_started_typing_remote_recv_not_in_room(self):
         self.room_members = [U_APPLE, U_ONION]
 
-        self.assertEquals(self.event_source.get_current_key(), 0)
+        self.assertEqual(self.event_source.get_current_key(), 0)
 
         channel = self.make_request(
             "PUT",
@@ -278,7 +278,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_not_called()
 
-        self.assertEquals(self.event_source.get_current_key(), 0)
+        self.assertEqual(self.event_source.get_current_key(), 0)
         events = self.get_success(
             self.event_source.get_new_events(
                 user=U_APPLE,
@@ -288,8 +288,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                 is_guest=False,
             )
         )
-        self.assertEquals(events[0], [])
-        self.assertEquals(events[1], 0)
+        self.assertEqual(events[0], [])
+        self.assertEqual(events[1], 0)
 
     @override_config({"send_federation": True})
     def test_stopped_typing(self):
@@ -302,7 +302,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.handler._member_typing_until[member] = 1002000
         self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()}
 
-        self.assertEquals(self.event_source.get_current_key(), 0)
+        self.assertEqual(self.event_source.get_current_key(), 0)
 
         self.get_success(
             self.handler.stopped_typing(
@@ -332,13 +332,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             try_trailing_slash_on_400=True,
         )
 
-        self.assertEquals(self.event_source.get_current_key(), 1)
+        self.assertEqual(self.event_source.get_current_key(), 1)
         events = self.get_success(
             self.event_source.get_new_events(
                 user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
             )
         )
-        self.assertEquals(
+        self.assertEqual(
             events[0],
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
         )
@@ -346,7 +346,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
     def test_typing_timeout(self):
         self.room_members = [U_APPLE, U_BANANA]
 
-        self.assertEquals(self.event_source.get_current_key(), 0)
+        self.assertEqual(self.event_source.get_current_key(), 0)
 
         self.get_success(
             self.handler.started_typing(
@@ -360,7 +360,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
         self.on_new_event.reset_mock()
 
-        self.assertEquals(self.event_source.get_current_key(), 1)
+        self.assertEqual(self.event_source.get_current_key(), 1)
         events = self.get_success(
             self.event_source.get_new_events(
                 user=U_APPLE,
@@ -370,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                 is_guest=False,
             )
         )
-        self.assertEquals(
+        self.assertEqual(
             events[0],
             [
                 {
@@ -385,7 +385,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
 
-        self.assertEquals(self.event_source.get_current_key(), 2)
+        self.assertEqual(self.event_source.get_current_key(), 2)
         events = self.get_success(
             self.event_source.get_new_events(
                 user=U_APPLE,
@@ -395,7 +395,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                 is_guest=False,
             )
         )
-        self.assertEquals(
+        self.assertEqual(
             events[0],
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
         )
@@ -414,7 +414,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])])
         self.on_new_event.reset_mock()
 
-        self.assertEquals(self.event_source.get_current_key(), 3)
+        self.assertEqual(self.event_source.get_current_key(), 3)
         events = self.get_success(
             self.event_source.get_new_events(
                 user=U_APPLE,
@@ -424,7 +424,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                 is_guest=False,
             )
         )
-        self.assertEquals(
+        self.assertEqual(
             events[0],
             [
                 {
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 482c90ef68..92012cd6f7 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -77,7 +77,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.handler = hs.get_user_directory_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
         self.event_creation_handler = self.hs.get_event_creation_handler()
@@ -1042,7 +1042,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
             b'{"search_term":"user2"}',
             access_token=u1_token,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["results"]) > 0)
 
         # Disable user directory and check search returns nothing
@@ -1053,5 +1053,5 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
             b'{"search_term":"user2"}',
             access_token=u1_token,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["results"]) == 0)
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index c49be33b9f..77ce8432ac 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -65,9 +65,9 @@ class SrvResolverTestCase(unittest.TestCase):
 
         servers = self.successResultOf(test_d)
 
-        self.assertEquals(len(servers), 1)
-        self.assertEquals(servers, cache[service_name])
-        self.assertEquals(servers[0].host, host_name)
+        self.assertEqual(len(servers), 1)
+        self.assertEqual(servers, cache[service_name])
+        self.assertEqual(servers[0].host, host_name)
 
     @defer.inlineCallbacks
     def test_from_cache_expired_and_dns_fail(self):
@@ -88,8 +88,8 @@ class SrvResolverTestCase(unittest.TestCase):
 
         dns_client_mock.lookupService.assert_called_once_with(service_name)
 
-        self.assertEquals(len(servers), 1)
-        self.assertEquals(servers, cache[service_name])
+        self.assertEqual(len(servers), 1)
+        self.assertEqual(servers, cache[service_name])
 
     @defer.inlineCallbacks
     def test_from_cache(self):
@@ -114,8 +114,8 @@ class SrvResolverTestCase(unittest.TestCase):
 
         self.assertFalse(dns_client_mock.lookupService.called)
 
-        self.assertEquals(len(servers), 1)
-        self.assertEquals(servers, cache[service_name])
+        self.assertEqual(len(servers), 1)
+        self.assertEqual(servers, cache[service_name])
 
     @defer.inlineCallbacks
     def test_empty_cache(self):
@@ -144,8 +144,8 @@ class SrvResolverTestCase(unittest.TestCase):
 
         servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
-        self.assertEquals(len(servers), 0)
-        self.assertEquals(len(cache), 0)
+        self.assertEqual(len(servers), 0)
+        self.assertEqual(len(cache), 0)
 
     def test_disabled_service(self):
         """
@@ -201,6 +201,6 @@ class SrvResolverTestCase(unittest.TestCase):
 
         servers = self.successResultOf(resolve_d)
 
-        self.assertEquals(len(servers), 1)
-        self.assertEquals(servers, cache[service_name])
-        self.assertEquals(servers[0].host, b"host")
+        self.assertEqual(len(servers), 1)
+        self.assertEqual(servers, cache[service_name])
+        self.assertEqual(servers[0].host, b"host")
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index d16cd141a7..c3f20f9692 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -41,7 +41,7 @@ class ModuleApiTestCase(HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastore()
+        self.store = homeserver.get_datastores().main
         self.module_api = homeserver.get_module_api()
         self.event_creation_handler = homeserver.get_event_creation_handler()
         self.sync_handler = homeserver.get_sync_handler()
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index f8cba7b645..7a3b0d6755 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -102,13 +102,13 @@ class EmailPusherTests(HomeserverTestCase):
 
         # Register the pusher
         user_tuple = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(self.access_token)
+            self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
         )
         self.token_id = user_tuple.token_id
 
         # We need to add email to account before we can create a pusher.
         self.get_success(
-            hs.get_datastore().user_add_threepid(
+            hs.get_datastores().main.user_add_threepid(
                 self.user_id, "email", "a@example.com", 0, 0
             )
         )
@@ -128,7 +128,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
 
         self.auth_handler = hs.get_auth_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def test_need_validated_email(self):
         """Test that we can only add an email pusher if the user has validated
@@ -375,7 +375,7 @@ class EmailPusherTests(HomeserverTestCase):
 
         # check that the pusher for that email address has been deleted
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 0)
@@ -388,14 +388,14 @@ class EmailPusherTests(HomeserverTestCase):
         # This resembles the old behaviour, which the background update below is intended
         # to clean up.
         self.get_success(
-            self.hs.get_datastore().user_delete_threepid(
+            self.hs.get_datastores().main.user_delete_threepid(
                 self.user_id, "email", "a@example.com"
             )
         )
 
         # Run the "remove_deleted_email_pushers" background job
         self.get_success(
-            self.hs.get_datastore().db_pool.simple_insert(
+            self.hs.get_datastores().main.db_pool.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "remove_deleted_email_pushers",
@@ -406,14 +406,14 @@ class EmailPusherTests(HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.hs.get_datastore().db_pool.updates._all_done = False
+        self.hs.get_datastores().main.db_pool.updates._all_done = False
 
         # Now let's actually drive the updates to completion
         self.wait_for_background_updates()
 
         # Check that all pushers with unlinked addresses were deleted
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 0)
@@ -428,7 +428,7 @@ class EmailPusherTests(HomeserverTestCase):
         """
         # Get the stream ordering before it gets sent
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
@@ -439,7 +439,7 @@ class EmailPusherTests(HomeserverTestCase):
 
         # It hasn't succeeded yet, so the stream ordering shouldn't have moved
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
@@ -458,7 +458,7 @@ class EmailPusherTests(HomeserverTestCase):
 
         # The stream ordering has increased
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
+            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index e1e3fb97c5..c284beb37c 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -62,7 +62,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Register the pusher
         user_tuple = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(access_token)
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
@@ -108,7 +108,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Register the pusher
         user_tuple = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(access_token)
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
@@ -138,7 +138,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Get the stream ordering before it gets sent
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by({"user_name": user_id})
+            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
@@ -149,7 +149,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # It hasn't succeeded yet, so the stream ordering shouldn't have moved
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by({"user_name": user_id})
+            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
@@ -170,7 +170,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # The stream ordering has increased
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by({"user_name": user_id})
+            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
@@ -192,7 +192,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # The stream ordering has increased, again
         pushers = self.get_success(
-            self.hs.get_datastore().get_pushers_by({"user_name": user_id})
+            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
@@ -224,7 +224,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Register the pusher
         user_tuple = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(access_token)
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
@@ -344,7 +344,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Register the pusher
         user_tuple = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(access_token)
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
@@ -430,7 +430,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Register the pusher
         user_tuple = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(access_token)
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
@@ -507,7 +507,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Register the pusher
         user_tuple = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(access_token)
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
@@ -613,7 +613,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         # Register the pusher
         user_tuple = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(access_token)
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index a52e89e407..3849beb9d6 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -14,6 +14,8 @@
 
 from typing import Any, Dict
 
+import frozendict
+
 from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent
 from synapse.push import push_rule_evaluator
@@ -191,6 +193,13 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             "pattern should only match at the start/end of the value",
         )
 
+        # it should work on frozendicts too
+        self._assert_matches(
+            condition,
+            frozendict.frozendict({"value": "FoobaZ"}),
+            "patterns should match on frozendicts",
+        )
+
         # wildcards should match
         condition = {
             "kind": "event_match",
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9fc50f8852..a7a05a564f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -68,7 +68,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
         # Since we use sqlite in memory databases we need to make sure the
         # databases objects are the same.
-        self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
+        self.worker_hs.get_datastores().main.db_pool = hs.get_datastores().main.db_pool
 
         # Normally we'd pass in the handler to `setup_test_homeserver`, which would
         # eventually hit "Install @cache_in_self attributes" in tests/utils.py.
@@ -233,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # We may have an attempt to connect to redis for the external cache already.
         self.connect_any_redis_attempts()
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.database_pool = store.db_pool
 
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -332,7 +332,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
                 lambda: self._handle_http_replication_attempt(worker_hs, port),
             )
 
-        store = worker_hs.get_datastore()
+        store = worker_hs.get_datastores().main
         store.db_pool._db_pool = self.database_pool._db_pool
 
         # Set up TCP replication between master and the new worker if we don't
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 83e89383f6..85be79d19d 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -30,8 +30,8 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
 
         self.reconnect()
 
-        self.master_store = hs.get_datastore()
-        self.slaved_store = self.worker_hs.get_datastore()
+        self.master_store = hs.get_datastores().main
+        self.slaved_store = self.worker_hs.get_datastores().main
         self.storage = hs.get_storage()
 
     def replicate(self):
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index eca6a443af..17dc42fd37 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -59,7 +59,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
     def setUp(self):
         # Patch up the equality operator for events so that we can check
-        # whether lists of events match using assertEquals
+        # whether lists of events match using assertEqual
         self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
         return super().setUp()
 
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index cdd052001b..50fbff5f32 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -23,7 +23,7 @@ from tests.replication._base import BaseStreamTestCase
 class AccountDataStreamTestCase(BaseStreamTestCase):
     def test_update_function_room_account_data_limit(self):
         """Test replication with many room account data updates"""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # generate lots of account data updates
         updates = []
@@ -69,7 +69,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
 
     def test_update_function_global_account_data_limit(self):
         """Test replication with many global account data updates"""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # generate lots of account data updates
         updates = []
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index f198a94887..f9d5da723c 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -136,7 +136,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         # this is the point in the DAG where we make a fork
         fork_point: List[str] = self.get_success(
-            self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
+            self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
         events = [
@@ -291,7 +291,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         # this is the point in the DAG where we make a fork
         fork_point: List[str] = self.get_success(
-            self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
+            self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
         events: List[EventBase] = []
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 38e292c1ab..eb00117845 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -32,7 +32,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
 
         # tell the master to send a new receipt
         self.get_success(
-            self.hs.get_datastore().insert_receipt(
+            self.hs.get_datastores().main.insert_receipt(
                 "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
             )
         )
@@ -56,7 +56,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         self.test_handler.on_rdata.reset_mock()
 
         self.get_success(
-            self.hs.get_datastore().insert_receipt(
+            self.hs.get_datastores().main.insert_receipt(
                 "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
             )
         )
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 92a5b53e11..ba1a63c0d6 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -204,7 +204,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
 
     def create_room_with_remote_server(self, user, token, remote_server="other_server"):
         room = self.helper.create_room_as(user, tok=token)
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         federation = self.hs.get_federation_event_handler()
 
         prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 4094a75f36..8f4f6688ce 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -50,7 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
 
         # Register a pusher
         user_dict = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(access_token)
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_dict.token_id
 
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 596ba5a0c9..5f142e84c3 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -47,7 +47,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.other_access_token = self.login("otheruser", "pass")
 
         self.room_creator = self.hs.get_room_creation_handler()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def default_config(self):
         conf = super().default_config()
@@ -99,7 +99,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         persisted_on_1 = False
         persisted_on_2 = False
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         user_id = self.register_user("user", "pass")
         access_token = self.login("user", "pass")
@@ -166,7 +166,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         user_id = self.register_user("user", "pass")
         access_token = self.login("user", "pass")
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # Create two room on the different workers.
         self._create_room(room_id1, user_id, access_token)
@@ -194,7 +194,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         #
         # Worker2's event stream position will not advance until we call
         # __aexit__ again.
-        worker_store2 = worker_hs2.get_datastore()
+        worker_store2 = worker_hs2.get_datastores().main
         assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
 
         actx = worker_store2._stream_id_gen.get_next()
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 1e3fe9c62c..fb36aa9940 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -36,7 +36,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 71068d16cd..929bbdc37d 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -35,7 +35,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
@@ -537,7 +537,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 86aff7575c..0d47dd0aff 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -634,7 +634,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         media_repo = hs.get_media_repository_resource()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.server_name = hs.hostname
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -767,7 +767,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         media_repo = hs.get_media_repository_resource()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 8513b1d2df..8354250ec2 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -34,7 +34,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 23da0ad736..95282f078e 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -50,7 +50,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
         self.event_creation_handler._consent_uri_builder = consent_uri_builder
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -465,7 +465,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
         self.event_creation_handler._consent_uri_builder = consent_uri_builder
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -1909,7 +1909,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
 
     def test_join_private_room_if_not_member(self) -> None:
@@ -1957,7 +1957,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.admin_user_tok,
         )
-        self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
         # Join user to room.
@@ -1980,7 +1980,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
     def test_join_private_room_if_owner(self) -> None:
@@ -2010,7 +2010,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/r0/joined_rooms",
             access_token=self.second_tok,
         )
-        self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
     def test_context_as_non_admin(self) -> None:
@@ -2044,7 +2044,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
                 % (room_id, events[midway]["event_id"]),
                 access_token=tok,
             )
-            self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+            self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
             self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_context_as_admin(self) -> None:
@@ -2074,8 +2074,8 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
             % (room_id, events[midway]["event_id"]),
             access_token=self.admin_user_tok,
         )
-        self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
-        self.assertEquals(
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self.assertEqual(
             channel.json_body["event"]["event_id"], events[midway]["event_id"]
         )
 
@@ -2239,7 +2239,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self._store = hs.get_datastore()
+        self._store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index 3c59f5f766..2c855bff99 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -38,7 +38,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
         self.pagination_handler = hs.get_pagination_handler()
         self.server_notices_manager = self.hs.get_server_notices_manager()
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 272637e965..a60ea0a563 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -410,7 +410,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         even if the MAU limit is reached.
         """
         handler = self.hs.get_registration_handler()
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # Set monthly active users to the limit
         store.get_monthly_active_count = Mock(
@@ -455,7 +455,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
     url = "/_synapse/admin/v2/users"
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -913,7 +913,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -1167,7 +1167,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.auth_handler = hs.get_auth_handler()
 
         # create users and get access tokens
@@ -2609,7 +2609,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -2737,7 +2737,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.media_repo = hs.get_media_repository_resource()
         self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
 
@@ -3317,7 +3317,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -3609,7 +3609,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -3687,7 +3687,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -3913,7 +3913,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 51146c471d..6c4462e74a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -1,6 +1,4 @@
-# Copyright 2015-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,16 +15,22 @@ import json
 import os
 import re
 from email.parser import Parser
-from typing import Optional
+from typing import Dict, List, Optional
+from unittest.mock import Mock
 
 import pkg_resources
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.constants import LoginType, Membership
 from synapse.api.errors import Codes, HttpResponseException
 from synapse.appservice import ApplicationService
+from synapse.rest import admin
 from synapse.rest.client import account, login, register, room
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeSite, make_request
@@ -73,7 +77,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
 
     def test_basic_password_reset(self):
@@ -100,7 +104,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         client_secret = "foobar"
         session_id = self._request_token(email, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEqual(len(self.email_attempts), 1)
         link = self._get_link_from_email()
 
         self._validate_token(link)
@@ -139,7 +143,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
             client_secret = "foobar"
             session_id = self._request_token(email, client_secret, ip)
 
-            self.assertEquals(len(self.email_attempts), 1)
+            self.assertEqual(len(self.email_attempts), 1)
             link = self._get_link_from_email()
 
             self._validate_token(link)
@@ -189,7 +193,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         client_secret = "foobar"
         session_id = self._request_token(email_passwort_reset, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEqual(len(self.email_attempts), 1)
         link = self._get_link_from_email()
 
         self._validate_token(link)
@@ -226,7 +230,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         client_secret = "foobar"
         session_id = self._request_token(email, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEqual(len(self.email_attempts), 1)
 
         # Attempt to reset password without clicking the link
         self._reset_password(new_password, session_id, client_secret, expected_code=401)
@@ -318,7 +322,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
             shorthand=False,
         )
 
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         # Now POST to the same endpoint, mimicking the same behaviour as clicking the
         # password reset confirm button
@@ -333,7 +337,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
             shorthand=False,
             content_is_form=True,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
     def _get_link_from_email(self):
         assert self.email_attempts, "No emails have been sent"
@@ -372,7 +376,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
                 },
             },
         )
-        self.assertEquals(expected_code, channel.code, channel.result)
+        self.assertEqual(expected_code, channel.code, channel.result)
 
 
 class DeactivateTestCase(unittest.HomeserverTestCase):
@@ -394,7 +398,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
 
         self.deactivate(user_id, tok)
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         # Check that the user has been marked as deactivated.
         self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
@@ -405,7 +409,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
 
     def test_pending_invites(self):
         """Tests that deactivating a user rejects every pending invite for them."""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         inviter_id = self.register_user("inviter", "test")
         inviter_tok = self.login("inviter", "test")
@@ -523,7 +527,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             namespaces={"users": [{"regex": user_id, "exclusive": True}]},
             sender=user_id,
         )
-        self.hs.get_datastore().services_cache.append(appservice)
+        self.hs.get_datastores().main.services_cache.append(appservice)
 
         whoami = self._whoami(as_token)
         self.assertEqual(
@@ -582,7 +586,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         return self.hs
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.user_id = self.register_user("kermit", "test")
         self.user_id_tok = self.login("kermit", "test")
@@ -672,7 +676,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         client_secret = "foobar"
         session_id = self._request_token(self.email, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEqual(len(self.email_attempts), 1)
         link = self._get_link_from_email()
 
         self._validate_token(link)
@@ -776,7 +780,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         client_secret = "foobar"
         session_id = self._request_token(self.email, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEqual(len(self.email_attempts), 1)
 
         # Attempt to add email without clicking the link
         channel = self.make_request(
@@ -977,7 +981,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         path = link.replace("https://example.com", "")
 
         channel = self.make_request("GET", path, shorthand=False)
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
     def _get_link_from_email(self):
         assert self.email_attempts, "No emails have been sent"
@@ -1006,7 +1010,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         client_secret = "foobar"
         session_id = self._request_token(request_email, client_secret)
 
-        self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
+        self.assertEqual(len(self.email_attempts) - previous_email_attempts, 1)
         link = self._get_link_from_email()
 
         self._validate_token(link)
@@ -1040,3 +1044,197 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
         self.assertIn(expected_email, threepids)
+
+
+class AccountStatusTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        account.register_servlets,
+        admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        config["experimental_features"] = {"msc3720_enabled": True}
+
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+        self.requester = self.register_user("requester", "password")
+        self.requester_tok = self.login("requester", "password")
+        self.server_name = homeserver.config.server.server_name
+
+    def test_missing_mxid(self):
+        """Tests that not providing any MXID raises an error."""
+        self._test_status(
+            users=None,
+            expected_status_code=400,
+            expected_errcode=Codes.MISSING_PARAM,
+        )
+
+    def test_invalid_mxid(self):
+        """Tests that providing an invalid MXID raises an error."""
+        self._test_status(
+            users=["bad:test"],
+            expected_status_code=400,
+            expected_errcode=Codes.INVALID_PARAM,
+        )
+
+    def test_local_user_not_exists(self):
+        """Tests that the account status endpoints correctly reports that a user doesn't
+        exist.
+        """
+        user = "@unknown:" + self.hs.config.server.server_name
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": False,
+                },
+            },
+            expected_failures=[],
+        )
+
+    def test_local_user_exists(self):
+        """Tests that the account status endpoint correctly reports that a user doesn't
+        exist.
+        """
+        user = self.register_user("someuser", "password")
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": False,
+                },
+            },
+            expected_failures=[],
+        )
+
+    def test_local_user_deactivated(self):
+        """Tests that the account status endpoint correctly reports a deactivated user."""
+        user = self.register_user("someuser", "password")
+        self.get_success(
+            self.hs.get_datastores().main.set_user_deactivated_status(
+                user, deactivated=True
+            )
+        )
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": True,
+                },
+            },
+            expected_failures=[],
+        )
+
+    def test_mixed_local_and_remote_users(self):
+        """Tests that if some users are remote the account status endpoint correctly
+        merges the remote responses with the local result.
+        """
+        # We use 3 users: one doesn't exist but belongs on the local homeserver, one is
+        # deactivated and belongs on one remote homeserver, and one belongs to another
+        # remote homeserver that didn't return any result (the federation code should
+        # mark that user as a failure).
+        users = [
+            "@unknown:" + self.hs.config.server.server_name,
+            "@deactivated:remote",
+            "@failed:otherremote",
+            "@bad:badremote",
+        ]
+
+        async def post_json(destination, path, data, *a, **kwa):
+            if destination == "remote":
+                return {
+                    "account_statuses": {
+                        users[1]: {
+                            "exists": True,
+                            "deactivated": True,
+                        },
+                    }
+                }
+            if destination == "otherremote":
+                return {}
+            if destination == "badremote":
+                # badremote tries to overwrite the status of a user that doesn't belong
+                # to it (i.e. users[1]) with false data, which Synapse is expected to
+                # ignore.
+                return {
+                    "account_statuses": {
+                        users[3]: {
+                            "exists": False,
+                        },
+                        users[1]: {
+                            "exists": False,
+                        },
+                    }
+                }
+
+        # Register a mock that will return the expected result depending on the remote.
+        self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
+
+        # Check that we've got the correct response from the client-side endpoint.
+        self._test_status(
+            users=users,
+            expected_statuses={
+                users[0]: {
+                    "exists": False,
+                },
+                users[1]: {
+                    "exists": True,
+                    "deactivated": True,
+                },
+                users[3]: {
+                    "exists": False,
+                },
+            },
+            expected_failures=[users[2]],
+        )
+
+    def _test_status(
+        self,
+        users: Optional[List[str]],
+        expected_status_code: int = 200,
+        expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
+        expected_failures: Optional[List[str]] = None,
+        expected_errcode: Optional[str] = None,
+    ):
+        """Send a request to the account status endpoint and check that the response
+        matches with what's expected.
+
+        Args:
+            users: The account(s) to request the status of, if any. If set to None, no
+                `user_id` query parameter will be included in the request.
+            expected_status_code: The expected HTTP status code.
+            expected_statuses: The expected account statuses, if any.
+            expected_failures: The expected failures, if any.
+            expected_errcode: The expected Matrix error code, if any.
+        """
+        content = {}
+        if users is not None:
+            content["user_ids"] = users
+
+        channel = self.make_request(
+            method="POST",
+            path=self.url,
+            content=content,
+            access_token=self.requester_tok,
+        )
+
+        self.assertEqual(channel.code, expected_status_code)
+
+        if expected_statuses is not None:
+            self.assertEqual(channel.json_body["account_statuses"], expected_statuses)
+
+        if expected_failures is not None:
+            self.assertEqual(channel.json_body["failures"], expected_failures)
+
+        if expected_errcode is not None:
+            self.assertEqual(channel.json_body["errcode"], expected_errcode)
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 4a68d66573..9653f45837 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -13,17 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from http import HTTPStatus
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client import account, auth, devices, login, logout, register
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.server import HomeServer
 from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
@@ -33,11 +37,11 @@ from tests.unittest import override_config, skip_unless
 
 
 class DummyRecaptchaChecker(UserInteractiveAuthChecker):
-    def __init__(self, hs):
+    def __init__(self, hs: HomeServer) -> None:
         super().__init__(hs)
-        self.recaptcha_attempts = []
+        self.recaptcha_attempts: List[Tuple[dict, str]] = []
 
-    def check_auth(self, authdict, clientip):
+    def check_auth(self, authdict: dict, clientip: str) -> Any:
         self.recaptcha_attempts.append((authdict, clientip))
         return succeed(True)
 
@@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
     ]
     hijack_auth = False
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
 
@@ -61,7 +65,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         hs = self.setup_test_homeserver(config=config)
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.recaptcha_checker = DummyRecaptchaChecker(hs)
         auth_handler = hs.get_auth_handler()
         auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
@@ -101,7 +105,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(len(attempts), 1)
         self.assertEqual(attempts[0][0]["response"], "a")
 
-    def test_fallback_captcha(self):
+    def test_fallback_captcha(self) -> None:
         """Ensure that fallback auth via a captcha works."""
         # Returns a 401 as per the spec
         channel = self.register(
@@ -132,7 +136,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         # We're given a registered user.
         self.assertEqual(channel.json_body["user_id"], "@user:test")
 
-    def test_complete_operation_unknown_session(self):
+    def test_complete_operation_unknown_session(self) -> None:
         """
         Attempting to mark an invalid session as complete should error.
         """
@@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         register.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
 
         # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
@@ -182,12 +186,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         return config
 
-    def create_resource_dict(self):
+    def create_resource_dict(self) -> Dict[str, Resource]:
         resource_dict = super().create_resource_dict()
         resource_dict.update(build_synapse_client_resource_tree(self.hs))
         return resource_dict
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
         self.device_id = "dev1"
@@ -229,7 +233,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         return channel
 
-    def test_ui_auth(self):
+    def test_ui_auth(self) -> None:
         """
         Test user interactive authentication outside of registration.
         """
@@ -259,7 +263,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_grandfathered_identifier(self):
+    def test_grandfathered_identifier(self) -> None:
         """Check behaviour without "identifier" dict
 
         Synapse used to require clients to submit a "user" field for m.login.password
@@ -286,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_can_change_body(self):
+    def test_can_change_body(self) -> None:
         """
         The client dict can be modified during the user interactive authentication session.
 
@@ -325,7 +329,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_cannot_change_uri(self):
+    def test_cannot_change_uri(self) -> None:
         """
         The initial requested URI cannot be modified during the user interactive authentication session.
         """
@@ -362,7 +366,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         )
 
     @unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
-    def test_can_reuse_session(self):
+    def test_can_reuse_session(self) -> None:
         """
         The session can be reused if configured.
 
@@ -409,7 +413,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
-    def test_ui_auth_via_sso(self):
+    def test_ui_auth_via_sso(self) -> None:
         """Test a successful UI Auth flow via SSO
 
         This includes:
@@ -452,7 +456,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
-    def test_does_not_offer_password_for_sso_user(self):
+    def test_does_not_offer_password_for_sso_user(self) -> None:
         login_resp = self.helper.login_via_oidc("username")
         user_tok = login_resp["access_token"]
         device_id = login_resp["device_id"]
@@ -464,7 +468,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         flows = channel.json_body["flows"]
         self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
 
-    def test_does_not_offer_sso_for_password_user(self):
+    def test_does_not_offer_sso_for_password_user(self) -> None:
         channel = self.delete_device(
             self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
         )
@@ -474,7 +478,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
-    def test_offers_both_flows_for_upgraded_user(self):
+    def test_offers_both_flows_for_upgraded_user(self) -> None:
         """A user that had a password and then logged in with SSO should get both flows"""
         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
         self.assertEqual(login_resp["user_id"], self.user)
@@ -491,7 +495,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
-    def test_ui_auth_fails_for_incorrect_sso_user(self):
+    def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
         """If the user tries to authenticate with the wrong SSO user, they get an error"""
         # log the user in
         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
@@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
     ]
     hijack_auth = False
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
 
@@ -548,7 +552,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": refresh_token},
         )
 
-    def is_access_token_valid(self, access_token) -> bool:
+    def is_access_token_valid(self, access_token: str) -> bool:
         """
         Checks whether an access token is valid, returning whether it is or not.
         """
@@ -561,7 +565,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
 
         return code == HTTPStatus.OK
 
-    def test_login_issue_refresh_token(self):
+    def test_login_issue_refresh_token(self) -> None:
         """
         A login response should include a refresh_token only if asked.
         """
@@ -591,7 +595,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         self.assertIn("refresh_token", login_with_refresh.json_body)
         self.assertIn("expires_in_ms", login_with_refresh.json_body)
 
-    def test_register_issue_refresh_token(self):
+    def test_register_issue_refresh_token(self) -> None:
         """
         A register response should include a refresh_token only if asked.
         """
@@ -627,7 +631,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         self.assertIn("refresh_token", register_with_refresh.json_body)
         self.assertIn("expires_in_ms", register_with_refresh.json_body)
 
-    def test_token_refresh(self):
+    def test_token_refresh(self) -> None:
         """
         A refresh token can be used to issue a new access token.
         """
@@ -665,7 +669,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         )
 
     @override_config({"refreshable_access_token_lifetime": "1m"})
-    def test_refreshable_access_token_expiration(self):
+    def test_refreshable_access_token_expiration(self) -> None:
         """
         The access token should have some time as specified in the config.
         """
@@ -722,7 +726,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "nonrefreshable_access_token_lifetime": "10m",
         }
     )
-    def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self):
+    def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(
+        self,
+    ) -> None:
         """
         Tests that the expiry times for refreshable and non-refreshable access
         tokens can be different.
@@ -782,7 +788,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
     @override_config(
         {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
     )
-    def test_refresh_token_expiry(self):
+    def test_refresh_token_expiry(self) -> None:
         """
         The refresh token can be configured to have a limited lifetime.
         When that lifetime has ended, the refresh token can no longer be used to
@@ -834,7 +840,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "session_lifetime": "3m",
         }
     )
-    def test_ultimate_session_expiry(self):
+    def test_ultimate_session_expiry(self) -> None:
         """
         The session can be configured to have an ultimate, limited lifetime.
         """
@@ -882,7 +888,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
         )
 
-    def test_refresh_token_invalidation(self):
+    def test_refresh_token_invalidation(self) -> None:
         """Refresh tokens are invalidated after first use of the next token.
 
         A refresh token is considered invalid if:
@@ -987,7 +993,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
         )
 
-    def test_many_token_refresh(self):
+    def test_many_token_refresh(self) -> None:
         """
         If a refresh is performed many times during a session, there shouldn't be
         extra 'cruft' built up over time.
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index 989e801768..d1751e1557 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -13,9 +13,13 @@
 # limitations under the License.
 from http import HTTPStatus
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.rest.client import capabilities, login
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.unittest import override_config
@@ -29,24 +33,24 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.url = b"/capabilities"
         hs = self.setup_test_homeserver()
         self.config = hs.config
         self.auth_handler = hs.get_auth_handler()
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.localpart = "user"
         self.password = "pass"
         self.user = self.register_user(self.localpart, self.password)
 
-    def test_check_auth_required(self):
+    def test_check_auth_required(self) -> None:
         channel = self.make_request("GET", self.url)
 
         self.assertEqual(channel.code, 401)
 
-    def test_get_room_version_capabilities(self):
+    def test_get_room_version_capabilities(self) -> None:
         access_token = self.login(self.localpart, self.password)
 
         channel = self.make_request("GET", self.url, access_token=access_token)
@@ -61,7 +65,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
             capabilities["m.room_versions"]["default"],
         )
 
-    def test_get_change_password_capabilities_password_login(self):
+    def test_get_change_password_capabilities_password_login(self) -> None:
         access_token = self.login(self.localpart, self.password)
 
         channel = self.make_request("GET", self.url, access_token=access_token)
@@ -71,7 +75,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertTrue(capabilities["m.change_password"]["enabled"])
 
     @override_config({"password_config": {"localdb_enabled": False}})
-    def test_get_change_password_capabilities_localdb_disabled(self):
+    def test_get_change_password_capabilities_localdb_disabled(self) -> None:
         access_token = self.get_success(
             self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
@@ -85,7 +89,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertFalse(capabilities["m.change_password"]["enabled"])
 
     @override_config({"password_config": {"enabled": False}})
-    def test_get_change_password_capabilities_password_disabled(self):
+    def test_get_change_password_capabilities_password_disabled(self) -> None:
         access_token = self.get_success(
             self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
@@ -98,7 +102,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertFalse(capabilities["m.change_password"]["enabled"])
 
-    def test_get_change_users_attributes_capabilities(self):
+    def test_get_change_users_attributes_capabilities(self) -> None:
         """Test that server returns capabilities by default."""
         access_token = self.login(self.localpart, self.password)
 
@@ -112,7 +116,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertTrue(capabilities["m.3pid_changes"]["enabled"])
 
     @override_config({"enable_set_displayname": False})
-    def test_get_set_displayname_capabilities_displayname_disabled(self):
+    def test_get_set_displayname_capabilities_displayname_disabled(self) -> None:
         """Test if set displayname is disabled that the server responds it."""
         access_token = self.login(self.localpart, self.password)
 
@@ -123,7 +127,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertFalse(capabilities["m.set_displayname"]["enabled"])
 
     @override_config({"enable_set_avatar_url": False})
-    def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
+    def test_get_set_avatar_url_capabilities_avatar_url_disabled(self) -> None:
         """Test if set avatar_url is disabled that the server responds it."""
         access_token = self.login(self.localpart, self.password)
 
@@ -134,7 +138,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
 
     @override_config({"enable_3pid_changes": False})
-    def test_get_change_3pid_capabilities_3pid_disabled(self):
+    def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
         """Test if change 3pid is disabled that the server responds it."""
         access_token = self.login(self.localpart, self.password)
 
@@ -145,7 +149,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertFalse(capabilities["m.3pid_changes"]["enabled"])
 
     @override_config({"experimental_features": {"msc3244_enabled": False}})
-    def test_get_does_not_include_msc3244_fields_when_disabled(self):
+    def test_get_does_not_include_msc3244_fields_when_disabled(self) -> None:
         access_token = self.get_success(
             self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
@@ -160,7 +164,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
             "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"]
         )
 
-    def test_get_does_include_msc3244_fields_when_enabled(self):
+    def test_get_does_include_msc3244_fields_when_enabled(self) -> None:
         access_token = self.get_success(
             self.auth_handler.create_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index fcdc565814..b1ca81a911 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -13,11 +13,16 @@
 # limitations under the License.
 
 import os
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.urls import ConsentURIBuilder
 from synapse.rest.client import login, room
 from synapse.rest.consent import consent_resource
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeSite, make_request
@@ -32,7 +37,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
     user_id = True
     hijack_auth = False
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
         config["form_secret"] = "123abc"
@@ -56,7 +61,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         hs = self.setup_test_homeserver(config=config)
         return hs
 
-    def test_render_public_consent(self):
+    def test_render_public_consent(self) -> None:
         """You can observe the terms form without specifying a user"""
         resource = consent_resource.ConsentResource(self.hs)
         channel = make_request(
@@ -66,9 +71,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
             "/consent?v=1",
             shorthand=False,
         )
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
 
-    def test_accept_consent(self):
+    def test_accept_consent(self) -> None:
         """
         A user can use the consent form to accept the terms.
         """
@@ -92,7 +97,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
             access_token=access_token,
             shorthand=False,
         )
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
 
         # Get the version from the body, and whether we've consented
         version, consented = channel.result["body"].decode("ascii").split(",")
@@ -107,7 +112,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
             access_token=access_token,
             shorthand=False,
         )
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
 
         # Fetch the consent page, to get the consent version -- it should have
         # changed
@@ -119,7 +124,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
             access_token=access_token,
             shorthand=False,
         )
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
 
         # Get the version from the body, and check that it's the version we
         # agreed to, and that we've consented to it.
diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py
index 16070cf027..a8af4e2435 100644
--- a/tests/rest/client/test_device_lists.py
+++ b/tests/rest/client/test_device_lists.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
+
 from synapse.rest import admin, devices, room, sync
 from synapse.rest.client import account, login, register
 
@@ -30,7 +32,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
         devices.register_servlets,
     ]
 
-    def test_receiving_local_device_list_changes(self):
+    def test_receiving_local_device_list_changes(self) -> None:
         """Tests that a local users that share a room receive each other's device list
         changes.
         """
@@ -84,7 +86,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
             },
             access_token=alice_access_token,
         )
-        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
 
         # Check that bob's incremental sync contains the updated device list.
         # If not, the client would only receive the device list update on the
@@ -97,7 +99,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
         )
         self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
 
-    def test_not_receiving_local_device_list_changes(self):
+    def test_not_receiving_local_device_list_changes(self) -> None:
         """Tests a local users DO NOT receive device updates from each other if they do not
         share a room.
         """
@@ -119,7 +121,7 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
             "/sync",
             access_token=bob_access_token,
         )
-        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
         next_batch_token = channel.json_body["next_batch"]
 
         # ...and then an incremental sync. This should block until the sync stream is woken up,
@@ -141,11 +143,13 @@ class DeviceListsTestCase(unittest.HomeserverTestCase):
             },
             access_token=alice_access_token,
         )
-        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
 
         # Check that bob's incremental sync does not contain the updated device list.
         bob_sync_channel.await_result()
-        self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
+        self.assertEqual(
+            bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
+        )
 
         changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
             "changed", []
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 3d7aa8ec86..9fa1f82dfe 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -11,9 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventContentFields, EventTypes
 from synapse.rest import admin
 from synapse.rest.client import room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -27,7 +34,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
 
         config["enable_ephemeral_messages"] = True
@@ -35,10 +42,10 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
         self.hs = self.setup_test_homeserver(config=config)
         return self.hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.room_id = self.helper.create_room_as(self.user_id)
 
-    def test_message_expiry_no_delay(self):
+    def test_message_expiry_no_delay(self) -> None:
         """Tests that sending a message sent with a m.self_destruct_after field set to the
         past results in that event being deleted right away.
         """
@@ -61,7 +68,7 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
         event_content = self.get_event(self.room_id, event_id)["content"]
         self.assertFalse(bool(event_content), event_content)
 
-    def test_message_expiry_delay(self):
+    def test_message_expiry_delay(self) -> None:
         """Tests that sending a message with a m.self_destruct_after field set to the
         future results in that event not being deleted right away, but advancing the
         clock to after that expiry timestamp causes the event to be deleted.
@@ -89,7 +96,9 @@ class EphemeralMessageTestCase(unittest.HomeserverTestCase):
         event_content = self.get_event(self.room_id, event_id)["content"]
         self.assertFalse(bool(event_content), event_content)
 
-    def get_event(self, room_id, event_id, expected_code=200):
+    def get_event(
+        self, room_id: str, event_id: str, expected_code: int = HTTPStatus.OK
+    ) -> JsonDict:
         url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
 
         channel = self.make_request("GET", url)
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index a90294003e..1b1392fa2f 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -16,8 +16,12 @@
 
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.rest.client import events, login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -32,7 +36,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
         config["enable_registration_captcha"] = False
@@ -41,11 +45,11 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(config=config)
 
-        hs.get_federation_handler = Mock()
+        hs.get_federation_handler = Mock()  # type: ignore[assignment]
 
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 
         # register an account
         self.user_id = self.register_user("sid1", "pass")
@@ -55,7 +59,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         self.other_user = self.register_user("other2", "pass")
         self.other_token = self.login(self.other_user, "pass")
 
-    def test_stream_basic_permissions(self):
+    def test_stream_basic_permissions(self) -> None:
         # invalid token, expect 401
         # note: this is in violation of the original v1 spec, which expected
         # 403. However, since the v1 spec no longer exists and the v1
@@ -65,18 +69,18 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "GET", "/events?access_token=%s" % ("invalid" + self.token,)
         )
-        self.assertEquals(channel.code, 401, msg=channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
 
         # valid token, expect content
         channel = self.make_request(
             "GET", "/events?access_token=%s&timeout=0" % (self.token,)
         )
-        self.assertEquals(channel.code, 200, msg=channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertTrue("chunk" in channel.json_body)
         self.assertTrue("start" in channel.json_body)
         self.assertTrue("end" in channel.json_body)
 
-    def test_stream_room_permissions(self):
+    def test_stream_room_permissions(self) -> None:
         room_id = self.helper.create_room_as(self.other_user, tok=self.other_token)
         self.helper.send(room_id, tok=self.other_token)
 
@@ -89,10 +93,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "GET", "/events?access_token=%s&timeout=0" % (self.token,)
         )
-        self.assertEquals(channel.code, 200, msg=channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         # We may get a presence event for ourselves down
-        self.assertEquals(
+        self.assertEqual(
             0,
             len(
                 [
@@ -111,7 +115,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
 
         # left to room (expect no content for room)
 
-    def TODO_test_stream_items(self):
+    def TODO_test_stream_items(self) -> None:
         # new user, no content
 
         # join room, expect 1 item (join)
@@ -136,7 +140,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, hs, reactor, clock):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 
         # register an account
         self.user_id = self.register_user("sid1", "pass")
@@ -144,7 +148,7 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
 
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
 
-    def test_get_event_via_events(self):
+    def test_get_event_via_events(self) -> None:
         resp = self.helper.send(self.room_id, tok=self.token)
         event_id = resp["event_id"]
 
@@ -153,4 +157,4 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
             "/events/" + event_id,
             access_token=self.token,
         )
-        self.assertEquals(channel.code, 200, msg=channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 475c6bed3d..5c31a54421 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -32,7 +32,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.filtering = hs.get_filtering()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def test_add_filter(self):
         channel = self.make_request(
@@ -45,7 +45,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body, {"filter_id": "0"})
         filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
         self.pump()
-        self.assertEquals(filter.result, self.EXAMPLE_FILTER)
+        self.assertEqual(filter.result, self.EXAMPLE_FILTER)
 
     def test_add_filter_for_other_user(self):
         channel = self.make_request(
@@ -55,7 +55,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(channel.result["code"], b"403")
-        self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
+        self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
 
     def test_add_filter_non_local_user(self):
         _is_mine = self.hs.is_mine
@@ -68,7 +68,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
         self.hs.is_mine = _is_mine
         self.assertEqual(channel.result["code"], b"403")
-        self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
+        self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
 
     def test_get_filter(self):
         filter_id = defer.ensureDeferred(
@@ -83,7 +83,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(channel.result["code"], b"200")
-        self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
+        self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
 
     def test_get_filter_non_existant(self):
         channel = self.make_request(
@@ -91,7 +91,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(channel.result["code"], b"404")
-        self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
     # Currently invalid params do not have an appropriate errcode
     # in errors.py
diff --git a/tests/rest/client/test_groups.py b/tests/rest/client/test_groups.py
index ad0425ae65..e067cf825c 100644
--- a/tests/rest/client/test_groups.py
+++ b/tests/rest/client/test_groups.py
@@ -25,13 +25,13 @@ class GroupsTestCase(unittest.HomeserverTestCase):
     servlets = [room.register_servlets, groups.register_servlets]
 
     @override_config({"enable_group_creation": True})
-    def test_rooms_limited_by_visibility(self):
+    def test_rooms_limited_by_visibility(self) -> None:
         group_id = "+spqr:test"
 
         # Alice creates a group
         channel = self.make_request("POST", "/create_group", {"localpart": "spqr"})
-        self.assertEquals(channel.code, 200, msg=channel.text_body)
-        self.assertEquals(channel.json_body, {"group_id": group_id})
+        self.assertEqual(channel.code, 200, msg=channel.text_body)
+        self.assertEqual(channel.json_body, {"group_id": group_id})
 
         # Bob creates a private room
         room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False)
@@ -45,12 +45,12 @@ class GroupsTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {}
         )
-        self.assertEquals(channel.code, 200, msg=channel.text_body)
-        self.assertEquals(channel.json_body, {})
+        self.assertEqual(channel.code, 200, msg=channel.text_body)
+        self.assertEqual(channel.json_body, {})
 
         # Alice now tries to retrieve the room list of the space.
         channel = self.make_request("GET", f"/groups/{group_id}/rooms")
-        self.assertEquals(channel.code, 200, msg=channel.text_body)
-        self.assertEquals(
+        self.assertEqual(channel.code, 200, msg=channel.text_body)
+        self.assertEqual(
             channel.json_body, {"chunk": [], "total_room_count_estimate": 0}
         )
diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index becb4e8dcc..299b9d21e2 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -13,9 +13,14 @@
 # limitations under the License.
 
 import json
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -28,7 +33,7 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
         config["enable_3pid_lookup"] = False
@@ -36,14 +41,14 @@ class IdentityTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def test_3pid_lookup_disabled(self):
+    def test_3pid_lookup_disabled(self) -> None:
         self.hs.config.registration.enable_3pid_lookup = False
 
         self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
         channel = self.make_request(b"POST", "/createRoom", b"{}", access_token=tok)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         room_id = channel.json_body["room_id"]
 
         params = {
@@ -56,4 +61,4 @@ class IdentityTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             b"POST", request_url, request_data, access_token=tok
         )
-        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index d7fa635eae..bbc8e74243 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -28,7 +28,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def test_rejects_device_id_ice_key_outside_of_list(self):
+    def test_rejects_device_id_ice_key_outside_of_list(self) -> None:
         self.register_user("alice", "wonderland")
         alice_token = self.login("alice", "wonderland")
         bob = self.register_user("bob", "uncle")
@@ -49,7 +49,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
             channel.result,
         )
 
-    def test_rejects_device_key_given_as_map_to_bool(self):
+    def test_rejects_device_key_given_as_map_to_bool(self) -> None:
         self.register_user("alice", "wonderland")
         alice_token = self.login("alice", "wonderland")
         bob = self.register_user("bob", "uncle")
@@ -73,7 +73,7 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
             channel.result,
         )
 
-    def test_requires_device_key(self):
+    def test_requires_device_key(self) -> None:
         """`device_keys` is required. We should complain if it's missing."""
         self.register_user("alice", "wonderland")
         alice_token = self.login("alice", "wonderland")
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 19f5e46537..090d2d0a29 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -20,6 +20,7 @@ from urllib.parse import urlencode
 
 import pymacaroons
 
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 import synapse.rest.admin
@@ -27,12 +28,15 @@ from synapse.appservice import ApplicationService
 from synapse.rest.client import devices, login, logout, register
 from synapse.rest.client.account import WhoamiRestServlet
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.server import HomeServer
 from synapse.types import create_requester
+from synapse.util import Clock
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
 from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.server import FakeChannel
 from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
@@ -95,7 +99,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.hs = self.setup_test_homeserver()
         self.hs.config.registration.enable_registration = True
         self.hs.config.registration.registrations_require_3pid = []
@@ -117,7 +121,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             }
         }
     )
-    def test_POST_ratelimiting_per_address(self):
+    def test_POST_ratelimiting_per_address(self) -> None:
         # Create different users so we're sure not to be bothered by the per-user
         # ratelimiter.
         for i in range(0, 6):
@@ -132,10 +136,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
-                self.assertEquals(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.result["code"], b"429", channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEquals(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.result["code"], b"200", channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
@@ -150,7 +154,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
     @override_config(
         {
@@ -165,7 +169,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             }
         }
     )
-    def test_POST_ratelimiting_per_account(self):
+    def test_POST_ratelimiting_per_account(self) -> None:
         self.register_user("kermit", "monkey")
 
         for i in range(0, 6):
@@ -177,10 +181,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
-                self.assertEquals(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.result["code"], b"429", channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEquals(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.result["code"], b"200", channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
@@ -195,7 +199,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
     @override_config(
         {
@@ -210,7 +214,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             }
         }
     )
-    def test_POST_ratelimiting_per_account_failed_attempts(self):
+    def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
         self.register_user("kermit", "monkey")
 
         for i in range(0, 6):
@@ -222,10 +226,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
-                self.assertEquals(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.result["code"], b"429", channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEquals(channel.result["code"], b"403", channel.result)
+                self.assertEqual(channel.result["code"], b"403", channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
@@ -240,16 +244,16 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.result["code"], b"403", channel.result)
 
     @override_config({"session_lifetime": "24h"})
-    def test_soft_logout(self):
+    def test_soft_logout(self) -> None:
         self.register_user("kermit", "monkey")
 
         # we shouldn't be able to make requests without an access token
         channel = self.make_request(b"GET", TEST_URL)
-        self.assertEquals(channel.result["code"], b"401", channel.result)
-        self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
 
         # log in as normal
         params = {
@@ -259,22 +263,22 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEquals(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         access_token = channel.json_body["access_token"]
         device_id = channel.json_body["device_id"]
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEquals(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEquals(channel.code, 401, channel.result)
-        self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
-        self.assertEquals(channel.json_body["soft_logout"], True)
+        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+        self.assertEqual(channel.json_body["soft_logout"], True)
 
         #
         # test behaviour after deleting the expired device
@@ -286,24 +290,26 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # more requests with the expired token should still return a soft-logout
         self.reactor.advance(3600)
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEquals(channel.code, 401, channel.result)
-        self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
-        self.assertEquals(channel.json_body["soft_logout"], True)
+        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+        self.assertEqual(channel.json_body["soft_logout"], True)
 
         # ... but if we delete that device, it will be a proper logout
         self._delete_device(access_token_2, "kermit", "monkey", device_id)
 
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEquals(channel.code, 401, channel.result)
-        self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
-        self.assertEquals(channel.json_body["soft_logout"], False)
+        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+        self.assertEqual(channel.json_body["soft_logout"], False)
 
-    def _delete_device(self, access_token, user_id, password, device_id):
+    def _delete_device(
+        self, access_token: str, user_id: str, password: str, device_id: str
+    ) -> None:
         """Perform the UI-Auth to delete a device"""
         channel = self.make_request(
             b"DELETE", "devices/" + device_id, access_token=access_token
         )
-        self.assertEquals(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         # check it's a UI-Auth fail
         self.assertEqual(
             set(channel.json_body.keys()),
@@ -326,10 +332,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             access_token=access_token,
             content={"auth": auth},
         )
-        self.assertEquals(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
     @override_config({"session_lifetime": "24h"})
-    def test_session_can_hard_logout_after_being_soft_logged_out(self):
+    def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
         self.register_user("kermit", "monkey")
 
         # log in as normal
@@ -337,23 +343,25 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEquals(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEquals(channel.code, 401, channel.result)
-        self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
-        self.assertEquals(channel.json_body["soft_logout"], True)
+        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+        self.assertEqual(channel.json_body["soft_logout"], True)
 
         # Now try to hard logout this session
         channel = self.make_request(b"POST", "/logout", access_token=access_token)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
     @override_config({"session_lifetime": "24h"})
-    def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
+    def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
+        self,
+    ) -> None:
         self.register_user("kermit", "monkey")
 
         # log in as normal
@@ -361,20 +369,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEquals(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEquals(channel.code, 401, channel.result)
-        self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
-        self.assertEquals(channel.json_body["soft_logout"], True)
+        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
+        self.assertEqual(channel.json_body["soft_logout"], True)
 
         # Now try to hard log out all of the user's sessions
         channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
 
 @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
@@ -432,7 +440,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
-    def test_get_login_flows(self):
+    def test_get_login_flows(self) -> None:
         """GET /login should return password and SSO flows"""
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)
@@ -459,12 +467,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             ],
         )
 
-    def test_multi_sso_redirect(self):
+    def test_multi_sso_redirect(self) -> None:
         """/login/sso/redirect should redirect to an identity picker"""
         # first hit the redirect url, which should redirect to our idp picker
         channel = self._make_sso_redirect_request(None)
         self.assertEqual(channel.code, 302, channel.result)
-        uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        uri = location_headers[0]
 
         # hitting that picker should give us some HTML
         channel = self.make_request("GET", uri)
@@ -487,7 +497,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
         self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
 
-    def test_multi_sso_redirect_to_cas(self):
+    def test_multi_sso_redirect_to_cas(self) -> None:
         """If CAS is chosen, should redirect to the CAS server"""
 
         channel = self.make_request(
@@ -514,7 +524,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         service_uri_params = urllib.parse.parse_qs(service_uri_query)
         self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
 
-    def test_multi_sso_redirect_to_saml(self):
+    def test_multi_sso_redirect_to_saml(self) -> None:
         """If SAML is chosen, should redirect to the SAML server"""
         channel = self.make_request(
             "GET",
@@ -536,7 +546,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         relay_state_param = saml_uri_params["RelayState"][0]
         self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
 
-    def test_login_via_oidc(self):
+    def test_login_via_oidc(self) -> None:
         """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
 
         # pick the default OIDC provider
@@ -604,7 +614,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(chan.code, 200, chan.result)
         self.assertEqual(chan.json_body["user_id"], "@user1:test")
 
-    def test_multi_sso_redirect_to_unknown(self):
+    def test_multi_sso_redirect_to_unknown(self) -> None:
         """An unknown IdP should cause a 400"""
         channel = self.make_request(
             "GET",
@@ -612,23 +622,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 400, channel.result)
 
-    def test_client_idp_redirect_to_unknown(self):
+    def test_client_idp_redirect_to_unknown(self) -> None:
         """If the client tries to pick an unknown IdP, return a 404"""
         channel = self._make_sso_redirect_request("xxx")
         self.assertEqual(channel.code, 404, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
 
-    def test_client_idp_redirect_to_oidc(self):
+    def test_client_idp_redirect_to_oidc(self) -> None:
         """If the client pick a known IdP, redirect to it"""
         channel = self._make_sso_redirect_request("oidc")
         self.assertEqual(channel.code, 302, channel.result)
-        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        oidc_uri = location_headers[0]
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
 
-    def _make_sso_redirect_request(self, idp_prov: Optional[str] = None):
+    def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel:
         """Send a request to /_matrix/client/r0/login/sso/redirect
 
         ... possibly specifying an IDP provider
@@ -659,7 +671,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.base_url = "https://matrix.goodserver.com/"
         self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
 
@@ -675,7 +687,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         cas_user_id = "username"
         self.user_id = "@%s:test" % cas_user_id
 
-        async def get_raw(uri, args):
+        async def get_raw(uri: str, args: Any) -> bytes:
             """Return an example response payload from a call to the `/proxyValidate`
             endpoint of a CAS server, copied from
             https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
@@ -709,10 +721,10 @@ class CASTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.deactivate_account_handler = hs.get_deactivate_account_handler()
 
-    def test_cas_redirect_confirm(self):
+    def test_cas_redirect_confirm(self) -> None:
         """Tests that the SSO login flow serves a confirmation page before redirecting a
         user to the redirect URL.
         """
@@ -754,15 +766,15 @@ class CASTestCase(unittest.HomeserverTestCase):
             }
         }
     )
-    def test_cas_redirect_whitelisted(self):
+    def test_cas_redirect_whitelisted(self) -> None:
         """Tests that the SSO login flow serves a redirect to a whitelisted url"""
         self._test_redirect("https://legit-site.com/")
 
     @override_config({"public_baseurl": "https://example.com"})
-    def test_cas_redirect_login_fallback(self):
+    def test_cas_redirect_login_fallback(self) -> None:
         self._test_redirect("https://example.com/_matrix/static/client/login")
 
-    def _test_redirect(self, redirect_url):
+    def _test_redirect(self, redirect_url: str) -> None:
         """Tests that the SSO login flow serves a redirect for the given redirect URL."""
         cas_ticket_url = (
             "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
@@ -778,7 +790,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
 
     @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
-    def test_deactivated_user(self):
+    def test_deactivated_user(self) -> None:
         """Logging in as a deactivated account should error."""
         redirect_url = "https://legit-site.com/"
 
@@ -821,7 +833,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
         "algorithm": jwt_algorithm,
     }
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
 
         # If jwt_config has been defined (eg via @override_config), don't replace it.
@@ -837,23 +849,23 @@ class JWTTestCase(unittest.HomeserverTestCase):
             return result.decode("ascii")
         return result
 
-    def jwt_login(self, *args):
+    def jwt_login(self, *args: Any) -> FakeChannel:
         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
-    def test_login_jwt_valid_registered(self):
+    def test_login_jwt_valid_registered(self) -> None:
         self.register_user("kermit", "monkey")
         channel = self.jwt_login({"sub": "kermit"})
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
-    def test_login_jwt_valid_unregistered(self):
+    def test_login_jwt_valid_unregistered(self) -> None:
         channel = self.jwt_login({"sub": "frog"})
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["user_id"], "@frog:test")
 
-    def test_login_jwt_invalid_signature(self):
+    def test_login_jwt_invalid_signature(self) -> None:
         channel = self.jwt_login({"sub": "frog"}, "notsecret")
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -862,7 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
             "JWT validation failed: Signature verification failed",
         )
 
-    def test_login_jwt_expired(self):
+    def test_login_jwt_expired(self) -> None:
         channel = self.jwt_login({"sub": "frog", "exp": 864000})
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -870,7 +882,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"], "JWT validation failed: Signature has expired"
         )
 
-    def test_login_jwt_not_before(self):
+    def test_login_jwt_not_before(self) -> None:
         now = int(time.time())
         channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
         self.assertEqual(channel.result["code"], b"403", channel.result)
@@ -880,14 +892,14 @@ class JWTTestCase(unittest.HomeserverTestCase):
             "JWT validation failed: The token is not yet valid (nbf)",
         )
 
-    def test_login_no_sub(self):
+    def test_login_no_sub(self) -> None:
         channel = self.jwt_login({"username": "root"})
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
     @override_config({"jwt_config": {**base_config, "issuer": "test-issuer"}})
-    def test_login_iss(self):
+    def test_login_iss(self) -> None:
         """Test validating the issuer claim."""
         # A valid issuer.
         channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
@@ -911,14 +923,14 @@ class JWTTestCase(unittest.HomeserverTestCase):
             'JWT validation failed: Token is missing the "iss" claim',
         )
 
-    def test_login_iss_no_config(self):
+    def test_login_iss_no_config(self) -> None:
         """Test providing an issuer claim without requiring it in the configuration."""
         channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
-    def test_login_aud(self):
+    def test_login_aud(self) -> None:
         """Test validating the audience claim."""
         # A valid audience.
         channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
@@ -942,7 +954,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
             'JWT validation failed: Token is missing the "aud" claim',
         )
 
-    def test_login_aud_no_config(self):
+    def test_login_aud_no_config(self) -> None:
         """Test providing an audience without requiring it in the configuration."""
         channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
         self.assertEqual(channel.result["code"], b"403", channel.result)
@@ -951,20 +963,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"], "JWT validation failed: Invalid audience"
         )
 
-    def test_login_default_sub(self):
+    def test_login_default_sub(self) -> None:
         """Test reading user ID from the default subject claim."""
         channel = self.jwt_login({"sub": "kermit"})
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
-    def test_login_custom_sub(self):
+    def test_login_custom_sub(self) -> None:
         """Test reading user ID from a custom subject claim."""
         channel = self.jwt_login({"username": "frog"})
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["user_id"], "@frog:test")
 
-    def test_login_no_token(self):
+    def test_login_no_token(self) -> None:
         params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         self.assertEqual(channel.result["code"], b"403", channel.result)
@@ -1026,7 +1038,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         ]
     )
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["jwt_config"] = {
             "enabled": True,
@@ -1042,17 +1054,17 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
             return result.decode("ascii")
         return result
 
-    def jwt_login(self, *args):
+    def jwt_login(self, *args: Any) -> FakeChannel:
         params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
         channel = self.make_request(b"POST", LOGIN_URL, params)
         return channel
 
-    def test_login_jwt_valid(self):
+    def test_login_jwt_valid(self) -> None:
         channel = self.jwt_login({"sub": "kermit"})
         self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
-    def test_login_jwt_invalid_signature(self):
+    def test_login_jwt_invalid_signature(self) -> None:
         channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
         self.assertEqual(channel.result["code"], b"403", channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -1071,7 +1083,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         register.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.hs = self.setup_test_homeserver()
 
         self.service = ApplicationService(
@@ -1101,11 +1113,11 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        self.hs.get_datastore().services_cache.append(self.service)
-        self.hs.get_datastore().services_cache.append(self.another_service)
+        self.hs.get_datastores().main.services_cache.append(self.service)
+        self.hs.get_datastores().main.services_cache.append(self.another_service)
         return self.hs
 
-    def test_login_appservice_user(self):
+    def test_login_appservice_user(self) -> None:
         """Test that an appservice user can use /login"""
         self.register_appservice_user(AS_USER, self.service.token)
 
@@ -1117,9 +1129,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
-    def test_login_appservice_user_bot(self):
+    def test_login_appservice_user_bot(self) -> None:
         """Test that the appservice bot can use /login"""
         self.register_appservice_user(AS_USER, self.service.token)
 
@@ -1131,9 +1143,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
-    def test_login_appservice_wrong_user(self):
+    def test_login_appservice_wrong_user(self) -> None:
         """Test that non-as users cannot login with the as token"""
         self.register_appservice_user(AS_USER, self.service.token)
 
@@ -1145,9 +1157,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
-        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.result["code"], b"403", channel.result)
 
-    def test_login_appservice_wrong_as(self):
+    def test_login_appservice_wrong_as(self) -> None:
         """Test that as users cannot login with wrong as token"""
         self.register_appservice_user(AS_USER, self.service.token)
 
@@ -1159,9 +1171,9 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.another_service.token
         )
 
-        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.result["code"], b"403", channel.result)
 
-    def test_login_appservice_no_token(self):
+    def test_login_appservice_no_token(self) -> None:
         """Test that users must provide a token when using the appservice
         login method
         """
@@ -1173,7 +1185,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
 
 
 @skip_unless(HAS_OIDC, "requires OIDC")
@@ -1182,7 +1194,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
 
     servlets = [login.register_servlets]
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
 
@@ -1202,7 +1214,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         d.update(build_synapse_client_resource_tree(self.hs))
         return d
 
-    def test_username_picker(self):
+    def test_username_picker(self) -> None:
         """Test the happy path of a username picker flow."""
 
         # do the start of the login flow
diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3cf5871899..3a74d2e96c 100644
--- a/tests/rest/client/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
@@ -13,11 +13,16 @@
 # limitations under the License.
 
 import json
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes
 from synapse.rest import admin
 from synapse.rest.client import account, login, password_policy, register
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -46,7 +51,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
         account.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.register_url = "/_matrix/client/r0/register"
         self.policy = {
             "enabled": True,
@@ -65,12 +70,12 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
         hs = self.setup_test_homeserver(config=config)
         return hs
 
-    def test_get_policy(self):
+    def test_get_policy(self) -> None:
         """Tests if the /password_policy endpoint returns the configured policy."""
 
         channel = self.make_request("GET", "/_matrix/client/r0/password_policy")
 
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
         self.assertEqual(
             channel.json_body,
             {
@@ -83,70 +88,70 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
             channel.result,
         )
 
-    def test_password_too_short(self):
+    def test_password_too_short(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": "shorty"})
         channel = self.make_request("POST", self.register_url, request_data)
 
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
         self.assertEqual(
             channel.json_body["errcode"],
             Codes.PASSWORD_TOO_SHORT,
             channel.result,
         )
 
-    def test_password_no_digit(self):
+    def test_password_no_digit(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
         channel = self.make_request("POST", self.register_url, request_data)
 
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
         self.assertEqual(
             channel.json_body["errcode"],
             Codes.PASSWORD_NO_DIGIT,
             channel.result,
         )
 
-    def test_password_no_symbol(self):
+    def test_password_no_symbol(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
         channel = self.make_request("POST", self.register_url, request_data)
 
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
         self.assertEqual(
             channel.json_body["errcode"],
             Codes.PASSWORD_NO_SYMBOL,
             channel.result,
         )
 
-    def test_password_no_uppercase(self):
+    def test_password_no_uppercase(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
         channel = self.make_request("POST", self.register_url, request_data)
 
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
         self.assertEqual(
             channel.json_body["errcode"],
             Codes.PASSWORD_NO_UPPERCASE,
             channel.result,
         )
 
-    def test_password_no_lowercase(self):
+    def test_password_no_lowercase(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
         channel = self.make_request("POST", self.register_url, request_data)
 
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
         self.assertEqual(
             channel.json_body["errcode"],
             Codes.PASSWORD_NO_LOWERCASE,
             channel.result,
         )
 
-    def test_password_compliant(self):
+    def test_password_compliant(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
         channel = self.make_request("POST", self.register_url, request_data)
 
         # Getting a 401 here means the password has passed validation and the server has
         # responded with a list of registration flows.
-        self.assertEqual(channel.code, 401, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
 
-    def test_password_change(self):
+    def test_password_change(self) -> None:
         """This doesn't test every possible use case, only that hitting /account/password
         triggers the password validation code.
         """
@@ -173,5 +178,5 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
             access_token=tok,
         )
 
-        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
         self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT)
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index c0de4c93a8..27dcfc83d2 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -11,11 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import Codes
 from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT
 from synapse.rest import admin
 from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
@@ -30,12 +35,12 @@ class PowerLevelsTestCase(HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
 
         return self.setup_test_homeserver(config=config)
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # register a room admin, moderator and regular user
         self.admin_user_id = self.register_user("admin", "pass")
         self.admin_access_token = self.login("admin", "pass")
@@ -88,7 +93,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
             tok=self.admin_access_token,
         )
 
-    def test_non_admins_cannot_enable_room_encryption(self):
+    def test_non_admins_cannot_enable_room_encryption(self) -> None:
         # have the mod try to enable room encryption
         self.helper.send_state(
             self.room_id,
@@ -104,10 +109,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
             "m.room.encryption",
             {"algorithm": "m.megolm.v1.aes-sha2"},
             tok=self.user_access_token,
-            expect_code=403,  # expect failure
+            expect_code=HTTPStatus.FORBIDDEN,  # expect failure
         )
 
-    def test_non_admins_cannot_send_server_acl(self):
+    def test_non_admins_cannot_send_server_acl(self) -> None:
         # have the mod try to send a server ACL
         self.helper.send_state(
             self.room_id,
@@ -118,7 +123,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
                 "deny": ["*.evil.com", "evil.com"],
             },
             tok=self.mod_access_token,
-            expect_code=403,  # expect failure
+            expect_code=HTTPStatus.FORBIDDEN,  # expect failure
         )
 
         # have the user try to send a server ACL
@@ -131,10 +136,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
                 "deny": ["*.evil.com", "evil.com"],
             },
             tok=self.user_access_token,
-            expect_code=403,  # expect failure
+            expect_code=HTTPStatus.FORBIDDEN,  # expect failure
         )
 
-    def test_non_admins_cannot_tombstone_room(self):
+    def test_non_admins_cannot_tombstone_room(self) -> None:
         # Create another room that will serve as our "upgraded room"
         self.upgraded_room_id = self.helper.create_room_as(
             self.admin_user_id, tok=self.admin_access_token
@@ -149,7 +154,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
                 "replacement_room": self.upgraded_room_id,
             },
             tok=self.mod_access_token,
-            expect_code=403,  # expect failure
+            expect_code=HTTPStatus.FORBIDDEN,  # expect failure
         )
 
         # have the user try to send a tombstone event
@@ -164,17 +169,17 @@ class PowerLevelsTestCase(HomeserverTestCase):
             expect_code=403,  # expect failure
         )
 
-    def test_admins_can_enable_room_encryption(self):
+    def test_admins_can_enable_room_encryption(self) -> None:
         # have the admin try to enable room encryption
         self.helper.send_state(
             self.room_id,
             "m.room.encryption",
             {"algorithm": "m.megolm.v1.aes-sha2"},
             tok=self.admin_access_token,
-            expect_code=200,  # expect success
+            expect_code=HTTPStatus.OK,  # expect success
         )
 
-    def test_admins_can_send_server_acl(self):
+    def test_admins_can_send_server_acl(self) -> None:
         # have the admin try to send a server ACL
         self.helper.send_state(
             self.room_id,
@@ -185,10 +190,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
                 "deny": ["*.evil.com", "evil.com"],
             },
             tok=self.admin_access_token,
-            expect_code=200,  # expect success
+            expect_code=HTTPStatus.OK,  # expect success
         )
 
-    def test_admins_can_tombstone_room(self):
+    def test_admins_can_tombstone_room(self) -> None:
         # Create another room that will serve as our "upgraded room"
         self.upgraded_room_id = self.helper.create_room_as(
             self.admin_user_id, tok=self.admin_access_token
@@ -203,10 +208,10 @@ class PowerLevelsTestCase(HomeserverTestCase):
                 "replacement_room": self.upgraded_room_id,
             },
             tok=self.admin_access_token,
-            expect_code=200,  # expect success
+            expect_code=HTTPStatus.OK,  # expect success
         )
 
-    def test_cannot_set_string_power_levels(self):
+    def test_cannot_set_string_power_levels(self) -> None:
         room_power_levels = self.helper.get_state(
             self.room_id,
             "m.room.power_levels",
@@ -221,7 +226,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
             "m.room.power_levels",
             room_power_levels,
             tok=self.admin_access_token,
-            expect_code=400,  # expect failure
+            expect_code=HTTPStatus.BAD_REQUEST,  # expect failure
         )
 
         self.assertEqual(
@@ -230,7 +235,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
             body,
         )
 
-    def test_cannot_set_unsafe_large_power_levels(self):
+    def test_cannot_set_unsafe_large_power_levels(self) -> None:
         room_power_levels = self.helper.get_state(
             self.room_id,
             "m.room.power_levels",
@@ -247,7 +252,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
             "m.room.power_levels",
             room_power_levels,
             tok=self.admin_access_token,
-            expect_code=400,  # expect failure
+            expect_code=HTTPStatus.BAD_REQUEST,  # expect failure
         )
 
         self.assertEqual(
@@ -256,7 +261,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
             body,
         )
 
-    def test_cannot_set_unsafe_small_power_levels(self):
+    def test_cannot_set_unsafe_small_power_levels(self) -> None:
         room_power_levels = self.helper.get_state(
             self.room_id,
             "m.room.power_levels",
@@ -273,7 +278,7 @@ class PowerLevelsTestCase(HomeserverTestCase):
             "m.room.power_levels",
             room_power_levels,
             tok=self.admin_access_token,
-            expect_code=400,  # expect failure
+            expect_code=HTTPStatus.BAD_REQUEST,  # expect failure
         )
 
         self.assertEqual(
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 56fe1a3d01..0abe378fe4 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -11,14 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from http import HTTPStatus
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.handlers.presence import PresenceHandler
 from synapse.rest.client import presence
+from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -31,7 +34,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
     user = UserID.from_string(user_id)
     servlets = [presence.register_servlets]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         presence_handler = Mock(spec=PresenceHandler)
         presence_handler.set_state.return_value = defer.succeed(None)
@@ -45,7 +48,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
 
         return hs
 
-    def test_put_presence(self):
+    def test_put_presence(self) -> None:
         """
         PUT to the status endpoint with use_presence enabled will call
         set_state on the presence handler.
@@ -57,11 +60,11 @@ class PresenceTestCase(unittest.HomeserverTestCase):
             "PUT", "/presence/%s/status" % (self.user_id,), body
         )
 
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
         self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
 
     @unittest.override_config({"use_presence": False})
-    def test_put_presence_disabled(self):
+    def test_put_presence_disabled(self) -> None:
         """
         PUT to the status endpoint with use_presence disabled will NOT call
         set_state on the presence handler.
@@ -72,5 +75,5 @@ class PresenceTestCase(unittest.HomeserverTestCase):
             "PUT", "/presence/%s/status" % (self.user_id,), body
         )
 
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
         self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index ead883ded8..77c3ced42e 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -13,12 +13,16 @@
 # limitations under the License.
 
 """Tests REST events for /profile paths."""
-from typing import Any, Dict
+from typing import Any, Dict, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import Codes
 from synapse.rest import admin
 from synapse.rest.client import login, profile, room
+from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -32,20 +36,20 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.hs = self.setup_test_homeserver()
         return self.hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.owner = self.register_user("owner", "pass")
         self.owner_tok = self.login("owner", "pass")
         self.other = self.register_user("other", "pass", displayname="Bob")
 
-    def test_get_displayname(self):
+    def test_get_displayname(self) -> None:
         res = self._get_displayname()
         self.assertEqual(res, "owner")
 
-    def test_set_displayname(self):
+    def test_set_displayname(self) -> None:
         channel = self.make_request(
             "PUT",
             "/profile/%s/displayname" % (self.owner,),
@@ -57,7 +61,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         res = self._get_displayname()
         self.assertEqual(res, "test")
 
-    def test_set_displayname_noauth(self):
+    def test_set_displayname_noauth(self) -> None:
         channel = self.make_request(
             "PUT",
             "/profile/%s/displayname" % (self.owner,),
@@ -65,7 +69,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 401, channel.result)
 
-    def test_set_displayname_too_long(self):
+    def test_set_displayname_too_long(self) -> None:
         """Attempts to set a stupid displayname should get a 400"""
         channel = self.make_request(
             "PUT",
@@ -78,11 +82,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         res = self._get_displayname()
         self.assertEqual(res, "owner")
 
-    def test_get_displayname_other(self):
+    def test_get_displayname_other(self) -> None:
         res = self._get_displayname(self.other)
-        self.assertEquals(res, "Bob")
+        self.assertEqual(res, "Bob")
 
-    def test_set_displayname_other(self):
+    def test_set_displayname_other(self) -> None:
         channel = self.make_request(
             "PUT",
             "/profile/%s/displayname" % (self.other,),
@@ -91,11 +95,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 400, channel.result)
 
-    def test_get_avatar_url(self):
+    def test_get_avatar_url(self) -> None:
         res = self._get_avatar_url()
         self.assertIsNone(res)
 
-    def test_set_avatar_url(self):
+    def test_set_avatar_url(self) -> None:
         channel = self.make_request(
             "PUT",
             "/profile/%s/avatar_url" % (self.owner,),
@@ -107,7 +111,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         res = self._get_avatar_url()
         self.assertEqual(res, "http://my.server/pic.gif")
 
-    def test_set_avatar_url_noauth(self):
+    def test_set_avatar_url_noauth(self) -> None:
         channel = self.make_request(
             "PUT",
             "/profile/%s/avatar_url" % (self.owner,),
@@ -115,7 +119,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 401, channel.result)
 
-    def test_set_avatar_url_too_long(self):
+    def test_set_avatar_url_too_long(self) -> None:
         """Attempts to set a stupid avatar_url should get a 400"""
         channel = self.make_request(
             "PUT",
@@ -128,11 +132,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         res = self._get_avatar_url()
         self.assertIsNone(res)
 
-    def test_get_avatar_url_other(self):
+    def test_get_avatar_url_other(self) -> None:
         res = self._get_avatar_url(self.other)
         self.assertIsNone(res)
 
-    def test_set_avatar_url_other(self):
+    def test_set_avatar_url_other(self) -> None:
         channel = self.make_request(
             "PUT",
             "/profile/%s/avatar_url" % (self.other,),
@@ -141,14 +145,14 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 400, channel.result)
 
-    def _get_displayname(self, name=None):
+    def _get_displayname(self, name: Optional[str] = None) -> str:
         channel = self.make_request(
             "GET", "/profile/%s/displayname" % (name or self.owner,)
         )
         self.assertEqual(channel.code, 200, channel.result)
         return channel.json_body["displayname"]
 
-    def _get_avatar_url(self, name=None):
+    def _get_avatar_url(self, name: Optional[str] = None) -> str:
         channel = self.make_request(
             "GET", "/profile/%s/avatar_url" % (name or self.owner,)
         )
@@ -156,7 +160,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         return channel.json_body.get("avatar_url")
 
     @unittest.override_config({"max_avatar_size": 50})
-    def test_avatar_size_limit_global(self):
+    def test_avatar_size_limit_global(self) -> None:
         """Tests that the maximum size limit for avatars is enforced when updating a
         global profile.
         """
@@ -187,7 +191,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
     @unittest.override_config({"max_avatar_size": 50})
-    def test_avatar_size_limit_per_room(self):
+    def test_avatar_size_limit_per_room(self) -> None:
         """Tests that the maximum size limit for avatars is enforced when updating a
         per-room profile.
         """
@@ -220,7 +224,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
     @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
-    def test_avatar_allowed_mime_type_global(self):
+    def test_avatar_allowed_mime_type_global(self) -> None:
         """Tests that the MIME type whitelist for avatars is enforced when updating a
         global profile.
         """
@@ -251,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
     @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
-    def test_avatar_allowed_mime_type_per_room(self):
+    def test_avatar_allowed_mime_type_per_room(self) -> None:
         """Tests that the MIME type whitelist for avatars is enforced when updating a
         per-room profile.
         """
@@ -283,7 +287,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 200, channel.result)
 
-    def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
+    def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
         """Stores metadata about files in the database.
 
         Args:
@@ -292,7 +296,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
                 properties are "mimetype" (for the file's type) and "size" (for the
                 file's size).
         """
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         for name, props in names_and_props.items():
             self.get_success(
@@ -316,8 +320,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
-
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["require_auth_for_profile_requests"] = True
         config["limit_profile_requests_to_users_who_share_rooms"] = True
@@ -325,7 +328,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # User owning the requested profile.
         self.owner = self.register_user("owner", "pass")
         self.owner_tok = self.login("owner", "pass")
@@ -337,22 +340,24 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
 
         self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok)
 
-    def test_no_auth(self):
+    def test_no_auth(self) -> None:
         self.try_fetch_profile(401)
 
-    def test_not_in_shared_room(self):
+    def test_not_in_shared_room(self) -> None:
         self.ensure_requester_left_room()
 
         self.try_fetch_profile(403, access_token=self.requester_tok)
 
-    def test_in_shared_room(self):
+    def test_in_shared_room(self) -> None:
         self.ensure_requester_left_room()
 
         self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok)
 
         self.try_fetch_profile(200, self.requester_tok)
 
-    def try_fetch_profile(self, expected_code, access_token=None):
+    def try_fetch_profile(
+        self, expected_code: int, access_token: Optional[str] = None
+    ) -> None:
         self.request_profile(expected_code, access_token=access_token)
 
         self.request_profile(
@@ -363,13 +368,18 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
             expected_code, url_suffix="/avatar_url", access_token=access_token
         )
 
-    def request_profile(self, expected_code, url_suffix="", access_token=None):
+    def request_profile(
+        self,
+        expected_code: int,
+        url_suffix: str = "",
+        access_token: Optional[str] = None,
+    ) -> None:
         channel = self.make_request(
             "GET", self.profile_url + url_suffix, access_token=access_token
         )
         self.assertEqual(channel.code, expected_code, channel.result)
 
-    def ensure_requester_left_room(self):
+    def ensure_requester_left_room(self) -> None:
         try:
             self.helper.leave(
                 room=self.room_id, user=self.requester, tok=self.requester_tok
@@ -389,7 +399,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
         profile.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["require_auth_for_profile_requests"] = True
         config["limit_profile_requests_to_users_who_share_rooms"] = True
@@ -397,12 +407,12 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # User requesting the profile.
         self.requester = self.register_user("requester", "pass")
         self.requester_tok = self.login("requester", "pass")
 
-    def test_can_lookup_own_profile(self):
+    def test_can_lookup_own_profile(self) -> None:
         """Tests that a user can lookup their own profile without having to be in a room
         if 'require_auth_for_profile_requests' is set to true in the server's config.
         """
diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
index d0ce91ccd9..4f875b9289 100644
--- a/tests/rest/client/test_push_rule_attrs.py
+++ b/tests/rest/client/test_push_rule_attrs.py
@@ -27,7 +27,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
     ]
     hijack_auth = False
 
-    def test_enabled_on_creation(self):
+    def test_enabled_on_creation(self) -> None:
         """
         Tests the GET and PUT of push rules' `enabled` endpoints.
         Tests that a rule is enabled upon creation, even though a rule with that
@@ -56,7 +56,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["enabled"], True)
 
-    def test_enabled_on_recreation(self):
+    def test_enabled_on_recreation(self) -> None:
         """
         Tests the GET and PUT of push rules' `enabled` endpoints.
         Tests that a rule is enabled upon creation, even if a rule with that
@@ -113,7 +113,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["enabled"], True)
 
-    def test_enabled_disable(self):
+    def test_enabled_disable(self) -> None:
         """
         Tests the GET and PUT of push rules' `enabled` endpoints.
         Tests that a rule is disabled and enabled when we ask for it.
@@ -166,7 +166,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["enabled"], True)
 
-    def test_enabled_404_when_get_non_existent(self):
+    def test_enabled_404_when_get_non_existent(self) -> None:
         """
         Tests that `enabled` gives 404 when the rule doesn't exist.
         """
@@ -212,7 +212,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
-    def test_enabled_404_when_get_non_existent_server_rule(self):
+    def test_enabled_404_when_get_non_existent_server_rule(self) -> None:
         """
         Tests that `enabled` gives 404 when the server-default rule doesn't exist.
         """
@@ -226,7 +226,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
-    def test_enabled_404_when_put_non_existent_rule(self):
+    def test_enabled_404_when_put_non_existent_rule(self) -> None:
         """
         Tests that `enabled` gives 404 when we put to a rule that doesn't exist.
         """
@@ -243,7 +243,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
-    def test_enabled_404_when_put_non_existent_server_rule(self):
+    def test_enabled_404_when_put_non_existent_server_rule(self) -> None:
         """
         Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist.
         """
@@ -260,7 +260,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
-    def test_actions_get(self):
+    def test_actions_get(self) -> None:
         """
         Tests that `actions` gives you what you expect on a fresh rule.
         """
@@ -289,7 +289,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
             channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}]
         )
 
-    def test_actions_put(self):
+    def test_actions_put(self) -> None:
         """
         Tests that PUT on actions updates the value you'd get from GET.
         """
@@ -325,7 +325,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["actions"], ["dont_notify"])
 
-    def test_actions_404_when_get_non_existent(self):
+    def test_actions_404_when_get_non_existent(self) -> None:
         """
         Tests that `actions` gives 404 when the rule doesn't exist.
         """
@@ -365,7 +365,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
-    def test_actions_404_when_get_non_existent_server_rule(self):
+    def test_actions_404_when_get_non_existent_server_rule(self) -> None:
         """
         Tests that `actions` gives 404 when the server-default rule doesn't exist.
         """
@@ -379,7 +379,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
-    def test_actions_404_when_put_non_existent_rule(self):
+    def test_actions_404_when_put_non_existent_rule(self) -> None:
         """
         Tests that `actions` gives 404 when putting to a rule that doesn't exist.
         """
@@ -396,7 +396,7 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
         self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
-    def test_actions_404_when_put_non_existent_server_rule(self):
+    def test_actions_404_when_put_non_existent_server_rule(self) -> None:
         """
         Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist.
         """
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 433d715f69..7401b5e0c0 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -11,9 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin
 from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
@@ -28,7 +34,7 @@ class RedactionsTestCase(HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
 
         config["rc_message"] = {"per_second": 0.2, "burst_count": 10}
@@ -36,7 +42,7 @@ class RedactionsTestCase(HomeserverTestCase):
 
         return self.setup_test_homeserver(config=config)
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # register a couple of users
         self.mod_user_id = self.register_user("user1", "pass")
         self.mod_access_token = self.login("user1", "pass")
@@ -60,7 +66,9 @@ class RedactionsTestCase(HomeserverTestCase):
             room=self.room_id, user=self.other_user_id, tok=self.other_access_token
         )
 
-    def _redact_event(self, access_token, room_id, event_id, expect_code=200):
+    def _redact_event(
+        self, access_token: str, room_id: str, event_id: str, expect_code: int = 200
+    ) -> JsonDict:
         """Helper function to send a redaction event.
 
         Returns the json body.
@@ -71,13 +79,13 @@ class RedactionsTestCase(HomeserverTestCase):
         self.assertEqual(int(channel.result["code"]), expect_code)
         return channel.json_body
 
-    def _sync_room_timeline(self, access_token, room_id):
+    def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
         channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
         self.assertEqual(channel.result["code"], b"200")
         room_sync = channel.json_body["rooms"]["join"][room_id]
         return room_sync["timeline"]["events"]
 
-    def test_redact_event_as_moderator(self):
+    def test_redact_event_as_moderator(self) -> None:
         # as a regular user, send a message to redact
         b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
         msg_id = b["event_id"]
@@ -98,7 +106,7 @@ class RedactionsTestCase(HomeserverTestCase):
         self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id)
         self.assertEqual(timeline[-2]["content"], {})
 
-    def test_redact_event_as_normal(self):
+    def test_redact_event_as_normal(self) -> None:
         # as a regular user, send a message to redact
         b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
         normal_msg_id = b["event_id"]
@@ -133,7 +141,7 @@ class RedactionsTestCase(HomeserverTestCase):
         self.assertEqual(timeline[-3]["unsigned"]["redacted_by"], redaction_id)
         self.assertEqual(timeline[-3]["content"], {})
 
-    def test_redact_nonexistent_event(self):
+    def test_redact_nonexistent_event(self) -> None:
         # control case: an existing event
         b = self.helper.send(room_id=self.room_id, tok=self.other_access_token)
         msg_id = b["event_id"]
@@ -158,7 +166,7 @@ class RedactionsTestCase(HomeserverTestCase):
         self.assertEqual(timeline[-2]["unsigned"]["redacted_by"], redaction_id)
         self.assertEqual(timeline[-2]["content"], {})
 
-    def test_redact_create_event(self):
+    def test_redact_create_event(self) -> None:
         # control case: an existing event
         b = self.helper.send(room_id=self.room_id, tok=self.mod_access_token)
         msg_id = b["event_id"]
@@ -178,7 +186,7 @@ class RedactionsTestCase(HomeserverTestCase):
             self.other_access_token, self.room_id, create_event_id, expect_code=403
         )
 
-    def test_redact_event_as_moderator_ratelimit(self):
+    def test_redact_event_as_moderator_ratelimit(self) -> None:
         """Tests that the correct ratelimiting is applied to redactions"""
 
         message_ids = []
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 0f1c47dcbb..9aebf1735a 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -16,15 +16,21 @@
 import datetime
 import json
 import os
+from typing import Any, Dict, List, Tuple
 
 import pkg_resources
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.server import HomeServer
 from synapse.storage._base import db_to_json
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.unittest import override_config
@@ -39,12 +45,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     ]
     url = b"/_matrix/client/r0/register"
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["allow_guest_access"] = True
         return config
 
-    def test_POST_appservice_registration_valid(self):
+    def test_POST_appservice_registration_valid(self) -> None:
         user_id = "@as_user_kermit:test"
         as_token = "i_am_an_app_service"
 
@@ -56,7 +62,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             sender="@as:test",
         )
 
-        self.hs.get_datastore().services_cache.append(appservice)
+        self.hs.get_datastores().main.services_cache.append(appservice)
         request_data = json.dumps(
             {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
         )
@@ -65,11 +71,11 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
         det_data = {"user_id": user_id, "home_server": self.hs.hostname}
         self.assertDictContainsSubset(det_data, channel.json_body)
 
-    def test_POST_appservice_registration_no_type(self):
+    def test_POST_appservice_registration_no_type(self) -> None:
         as_token = "i_am_an_app_service"
 
         appservice = ApplicationService(
@@ -80,16 +86,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             sender="@as:test",
         )
 
-        self.hs.get_datastore().services_cache.append(appservice)
+        self.hs.get_datastores().main.services_cache.append(appservice)
         request_data = json.dumps({"username": "as_user_kermit"})
 
         channel = self.make_request(
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
-        self.assertEquals(channel.result["code"], b"400", channel.result)
+        self.assertEqual(channel.result["code"], b"400", channel.result)
 
-    def test_POST_appservice_registration_invalid(self):
+    def test_POST_appservice_registration_invalid(self) -> None:
         self.appservice = None  # no application service exists
         request_data = json.dumps(
             {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
@@ -98,23 +104,23 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
-        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
 
-    def test_POST_bad_password(self):
+    def test_POST_bad_password(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": 666})
         channel = self.make_request(b"POST", self.url, request_data)
 
-        self.assertEquals(channel.result["code"], b"400", channel.result)
-        self.assertEquals(channel.json_body["error"], "Invalid password")
+        self.assertEqual(channel.result["code"], b"400", channel.result)
+        self.assertEqual(channel.json_body["error"], "Invalid password")
 
-    def test_POST_bad_username(self):
+    def test_POST_bad_username(self) -> None:
         request_data = json.dumps({"username": 777, "password": "monkey"})
         channel = self.make_request(b"POST", self.url, request_data)
 
-        self.assertEquals(channel.result["code"], b"400", channel.result)
-        self.assertEquals(channel.json_body["error"], "Invalid username")
+        self.assertEqual(channel.result["code"], b"400", channel.result)
+        self.assertEqual(channel.json_body["error"], "Invalid username")
 
-    def test_POST_user_valid(self):
+    def test_POST_user_valid(self) -> None:
         user_id = "@kermit:test"
         device_id = "frogfone"
         params = {
@@ -131,58 +137,58 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "home_server": self.hs.hostname,
             "device_id": device_id,
         }
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
 
     @override_config({"enable_registration": False})
-    def test_POST_disabled_registration(self):
+    def test_POST_disabled_registration(self) -> None:
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
         self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
 
         channel = self.make_request(b"POST", self.url, request_data)
 
-        self.assertEquals(channel.result["code"], b"403", channel.result)
-        self.assertEquals(channel.json_body["error"], "Registration has been disabled")
-        self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["error"], "Registration has been disabled")
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
 
-    def test_POST_guest_registration(self):
+    def test_POST_guest_registration(self) -> None:
         self.hs.config.key.macaroon_secret_key = "test"
         self.hs.config.registration.allow_guest_access = True
 
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
 
-    def test_POST_disabled_guest_registration(self):
+    def test_POST_disabled_guest_registration(self) -> None:
         self.hs.config.registration.allow_guest_access = False
 
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
-        self.assertEquals(channel.result["code"], b"403", channel.result)
-        self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["error"], "Guest access is disabled")
 
     @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
-    def test_POST_ratelimiting_guest(self):
+    def test_POST_ratelimiting_guest(self) -> None:
         for i in range(0, 6):
             url = self.url + b"?kind=guest"
             channel = self.make_request(b"POST", url, b"{}")
 
             if i == 5:
-                self.assertEquals(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.result["code"], b"429", channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEquals(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.result["code"], b"200", channel.result)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
     @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
-    def test_POST_ratelimiting(self):
+    def test_POST_ratelimiting(self) -> None:
         for i in range(0, 6):
             params = {
                 "username": "kermit" + str(i),
@@ -194,23 +200,23 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", self.url, request_data)
 
             if i == 5:
-                self.assertEquals(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.result["code"], b"429", channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEquals(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.result["code"], b"200", channel.result)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_requires_token(self):
+    def test_POST_registration_requires_token(self) -> None:
         username = "kermit"
         device_id = "frogfone"
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "registration_tokens",
@@ -223,7 +229,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 },
             )
         )
-        params = {
+        params: JsonDict = {
             "username": username,
             "password": "monkey",
             "device_id": device_id,
@@ -231,7 +237,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         # Request without auth to get flows and session
         channel = self.make_request(b"POST", self.url, json.dumps(params))
-        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
         # Synapse adds a dummy stage to differentiate flows where otherwise one
         # flow would be a subset of another flow.
@@ -249,7 +255,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         }
         request_data = json.dumps(params)
         channel = self.make_request(b"POST", self.url, request_data)
-        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
         completed = channel.json_body["completed"]
         self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
 
@@ -265,7 +271,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "home_server": self.hs.hostname,
             "device_id": device_id,
         }
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
 
         # Check the `completed` counter has been incremented and pending is 0
@@ -276,12 +282,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 retcols=["pending", "completed"],
             )
         )
-        self.assertEquals(res["completed"], 1)
-        self.assertEquals(res["pending"], 0)
+        self.assertEqual(res["completed"], 1)
+        self.assertEqual(res["pending"], 0)
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_invalid(self):
-        params = {
+    def test_POST_registration_token_invalid(self) -> None:
+        params: JsonDict = {
             "username": "kermit",
             "password": "monkey",
         }
@@ -295,28 +301,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "session": session,
         }
         channel = self.make_request(b"POST", self.url, json.dumps(params))
-        self.assertEquals(channel.result["code"], b"401", channel.result)
-        self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
-        self.assertEquals(channel.json_body["completed"], [])
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
+        self.assertEqual(channel.json_body["completed"], [])
 
         # Test with non-string (invalid)
         params["auth"]["token"] = 1234
         channel = self.make_request(b"POST", self.url, json.dumps(params))
-        self.assertEquals(channel.result["code"], b"401", channel.result)
-        self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
-        self.assertEquals(channel.json_body["completed"], [])
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+        self.assertEqual(channel.json_body["completed"], [])
 
         # Test with unknown token (invalid)
         params["auth"]["token"] = "1234"
         channel = self.make_request(b"POST", self.url, json.dumps(params))
-        self.assertEquals(channel.result["code"], b"401", channel.result)
-        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
-        self.assertEquals(channel.json_body["completed"], [])
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEqual(channel.json_body["completed"], [])
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_limit_uses(self):
+    def test_POST_registration_token_limit_uses(self) -> None:
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         # Create token that can be used once
         self.get_success(
             store.db_pool.simple_insert(
@@ -330,8 +336,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 },
             )
         )
-        params1 = {"username": "bert", "password": "monkey"}
-        params2 = {"username": "ernie", "password": "monkey"}
+        params1: JsonDict = {"username": "bert", "password": "monkey"}
+        params2: JsonDict = {"username": "ernie", "password": "monkey"}
         # Do 2 requests without auth to get two session IDs
         channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
         session1 = channel1.json_body["session"]
@@ -354,7 +360,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 retcol="pending",
             )
         )
-        self.assertEquals(pending, 1)
+        self.assertEqual(pending, 1)
 
         # Check auth fails when using token with session2
         params2["auth"] = {
@@ -363,9 +369,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "session": session2,
         }
         channel = self.make_request(b"POST", self.url, json.dumps(params2))
-        self.assertEquals(channel.result["code"], b"401", channel.result)
-        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
-        self.assertEquals(channel.json_body["completed"], [])
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEqual(channel.json_body["completed"], [])
 
         # Complete registration with session1
         params1["auth"]["type"] = LoginType.DUMMY
@@ -378,20 +384,20 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 retcols=["pending", "completed"],
             )
         )
-        self.assertEquals(res["pending"], 0)
-        self.assertEquals(res["completed"], 1)
+        self.assertEqual(res["pending"], 0)
+        self.assertEqual(res["completed"], 1)
 
         # Check auth still fails when using token with session2
         channel = self.make_request(b"POST", self.url, json.dumps(params2))
-        self.assertEquals(channel.result["code"], b"401", channel.result)
-        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
-        self.assertEquals(channel.json_body["completed"], [])
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEqual(channel.json_body["completed"], [])
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_expiry(self):
+    def test_POST_registration_token_expiry(self) -> None:
         token = "abcd"
         now = self.hs.get_clock().time_msec()
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         # Create token that expired yesterday
         self.get_success(
             store.db_pool.simple_insert(
@@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 },
             )
         )
-        params = {"username": "kermit", "password": "monkey"}
+        params: JsonDict = {"username": "kermit", "password": "monkey"}
         # Request without auth to get session
         channel = self.make_request(b"POST", self.url, json.dumps(params))
         session = channel.json_body["session"]
@@ -417,9 +423,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "session": session,
         }
         channel = self.make_request(b"POST", self.url, json.dumps(params))
-        self.assertEquals(channel.result["code"], b"401", channel.result)
-        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
-        self.assertEquals(channel.json_body["completed"], [])
+        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEqual(channel.json_body["completed"], [])
 
         # Update token so it expires tomorrow
         self.get_success(
@@ -436,10 +442,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_session_expiry(self):
+    def test_POST_registration_token_session_expiry(self) -> None:
         """Test `pending` is decremented when an uncompleted session expires."""
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "registration_tokens",
@@ -454,8 +460,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         )
 
         # Do 2 requests without auth to get two session IDs
-        params1 = {"username": "bert", "password": "monkey"}
-        params2 = {"username": "ernie", "password": "monkey"}
+        params1: JsonDict = {"username": "bert", "password": "monkey"}
+        params2: JsonDict = {"username": "ernie", "password": "monkey"}
         channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
         session1 = channel1.json_body["session"]
         channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
@@ -504,7 +510,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 retcol="result",
             )
         )
-        self.assertEquals(db_to_json(result2), token)
+        self.assertEqual(db_to_json(result2), token)
 
         # Delete both sessions (mimics expiry)
         self.get_success(
@@ -519,10 +525,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
                 retcol="pending",
             )
         )
-        self.assertEquals(pending, 0)
+        self.assertEqual(pending, 0)
 
     @override_config({"registration_requires_token": True})
-    def test_POST_registration_token_session_expiry_deleted_token(self):
+    def test_POST_registration_token_session_expiry_deleted_token(self) -> None:
         """Test session expiry doesn't break when the token is deleted.
 
         1. Start but don't complete UIA with a registration token
@@ -530,7 +536,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         3. Expire the session
         """
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "registration_tokens",
@@ -545,7 +551,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         )
 
         # Do request without auth to get a session ID
-        params = {"username": "kermit", "password": "monkey"}
+        params: JsonDict = {"username": "kermit", "password": "monkey"}
         channel = self.make_request(b"POST", self.url, json.dumps(params))
         session = channel.json_body["session"]
 
@@ -570,9 +576,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
         )
 
-    def test_advertised_flows(self):
+    def test_advertised_flows(self) -> None:
         channel = self.make_request(b"POST", self.url, b"{}")
-        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
 
         # with the stock config, we only expect the dummy flow
@@ -593,9 +599,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             },
         }
     )
-    def test_advertised_flows_captcha_and_terms_and_3pids(self):
+    def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
         channel = self.make_request(b"POST", self.url, b"{}")
-        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
 
         self.assertCountEqual(
@@ -625,9 +631,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             },
         }
     )
-    def test_advertised_flows_no_msisdn_email_required(self):
+    def test_advertised_flows_no_msisdn_email_required(self) -> None:
         channel = self.make_request(b"POST", self.url, b"{}")
-        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
         flows = channel.json_body["flows"]
 
         # with the stock config, we expect all four combinations of 3pid
@@ -646,7 +652,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             },
         }
     )
-    def test_request_token_existing_email_inhibit_error(self):
+    def test_request_token_existing_email_inhibit_error(self) -> None:
         """Test that requesting a token via this endpoint doesn't leak existing
         associations if configured that way.
         """
@@ -657,7 +663,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         # Add a threepid
         self.get_success(
-            self.hs.get_datastore().user_add_threepid(
+            self.hs.get_datastores().main.user_add_threepid(
                 user_id=user_id,
                 medium="email",
                 address=email,
@@ -671,7 +677,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             b"register/email/requestToken",
             {"client_secret": "foobar", "email": email, "send_attempt": 1},
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         self.assertIsNotNone(channel.json_body.get("sid"))
 
@@ -685,7 +691,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             },
         }
     )
-    def test_reject_invalid_email(self):
+    def test_reject_invalid_email(self) -> None:
         """Check that bad emails are rejected"""
 
         # Test for email with multiple @
@@ -694,9 +700,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             b"register/email/requestToken",
             {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1},
         )
-        self.assertEquals(400, channel.code, channel.result)
+        self.assertEqual(400, channel.code, channel.result)
         # Check error to ensure that we're not erroring due to a bug in the test.
-        self.assertEquals(
+        self.assertEqual(
             channel.json_body,
             {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
         )
@@ -707,8 +713,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             b"register/email/requestToken",
             {"client_secret": "foobar", "email": "email", "send_attempt": 1},
         )
-        self.assertEquals(400, channel.code, channel.result)
-        self.assertEquals(
+        self.assertEqual(400, channel.code, channel.result)
+        self.assertEqual(
             channel.json_body,
             {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
         )
@@ -720,8 +726,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             b"register/email/requestToken",
             {"client_secret": "foobar", "email": email, "send_attempt": 1},
         )
-        self.assertEquals(400, channel.code, channel.result)
-        self.assertEquals(
+        self.assertEqual(400, channel.code, channel.result)
+        self.assertEqual(
             channel.json_body,
             {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
         )
@@ -731,7 +737,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "inhibit_user_in_use_error": True,
         }
     )
-    def test_inhibit_user_in_use_error(self):
+    def test_inhibit_user_in_use_error(self) -> None:
         """Tests that the 'inhibit_user_in_use_error' configuration flag behaves
         correctly.
         """
@@ -745,7 +751,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         # Check that /available correctly ignores the username provided despite the
         # username being already registered.
         channel = self.make_request("GET", "register/available?username=" + username)
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         # Test that when starting a UIA registration flow the request doesn't fail because
         # of a conflicting username
@@ -779,7 +785,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         account_validity.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         # Test for account expiring after a week.
         config["enable_registration"] = True
@@ -791,7 +797,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def test_validity_period(self):
+    def test_validity_period(self) -> None:
         self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -799,18 +805,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         # endpoint.
         channel = self.make_request(b"GET", "/sync", access_token=tok)
 
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
 
         channel = self.make_request(b"GET", "/sync", access_token=tok)
 
-        self.assertEquals(channel.result["code"], b"403", channel.result)
-        self.assertEquals(
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(
             channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
         )
 
-    def test_manual_renewal(self):
+    def test_manual_renewal(self) -> None:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -826,14 +832,14 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         params = {"user_id": user_id}
         request_data = json.dumps(params)
         channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
         channel = self.make_request(b"GET", "/sync", access_token=tok)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
-    def test_manual_expire(self):
+    def test_manual_expire(self) -> None:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -848,17 +854,17 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         }
         request_data = json.dumps(params)
         channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
         channel = self.make_request(b"GET", "/sync", access_token=tok)
-        self.assertEquals(channel.result["code"], b"403", channel.result)
-        self.assertEquals(
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(
             channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
         )
 
-    def test_logging_out_expired_user(self):
+    def test_logging_out_expired_user(self) -> None:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -873,18 +879,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         }
         request_data = json.dumps(params)
         channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         # Try to log the user out
         channel = self.make_request(b"POST", "/logout", access_token=tok)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         # Log the user in again (allowed for expired accounts)
         tok = self.login("kermit", "monkey")
 
         # Try to log out all of the user's sessions
         channel = self.make_request(b"POST", "/logout/all", access_token=tok)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
 
 class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@@ -898,7 +904,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         account.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
 
         # Test for account expiring after a week and renewal emails being sent 2
@@ -935,17 +941,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
         self.hs = self.setup_test_homeserver(config=config)
 
-        async def sendmail(*args, **kwargs):
+        async def sendmail(*args: Any, **kwargs: Any) -> None:
             self.email_attempts.append((args, kwargs))
 
-        self.email_attempts = []
+        self.email_attempts: List[Tuple[Any, Any]] = []
         self.hs.get_send_email_handler()._sendmail = sendmail
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         return self.hs
 
-    def test_renewal_email(self):
+    def test_renewal_email(self) -> None:
         self.email_attempts = []
 
         (user_id, tok) = self.create_user()
@@ -959,7 +965,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
         url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
         channel = self.make_request(b"GET", url)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         # Check that we're getting HTML back.
         content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -977,7 +983,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # Move 1 day forward. Try to renew with the same token again.
         url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
         channel = self.make_request(b"GET", url)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         # Check that we're getting HTML back.
         content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -997,14 +1003,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # succeed.
         self.reactor.advance(datetime.timedelta(days=3).total_seconds())
         channel = self.make_request(b"GET", "/sync", access_token=tok)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
-    def test_renewal_invalid_token(self):
+    def test_renewal_invalid_token(self) -> None:
         # Hit the renewal endpoint with an invalid token and check that it behaves as
         # expected, i.e. that it responds with 404 Not Found and the correct HTML.
         url = "/_matrix/client/unstable/account_validity/renew?token=123"
         channel = self.make_request(b"GET", url)
-        self.assertEquals(channel.result["code"], b"404", channel.result)
+        self.assertEqual(channel.result["code"], b"404", channel.result)
 
         # Check that we're getting HTML back.
         content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -1019,7 +1025,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
             channel.result["body"], expected_html.encode("utf8"), channel.result
         )
 
-    def test_manual_email_send(self):
+    def test_manual_email_send(self) -> None:
         self.email_attempts = []
 
         (user_id, tok) = self.create_user()
@@ -1028,11 +1034,11 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/unstable/account_validity/send_mail",
             access_token=tok,
         )
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         self.assertEqual(len(self.email_attempts), 1)
 
-    def test_deactivated_user(self):
+    def test_deactivated_user(self) -> None:
         self.email_attempts = []
 
         (user_id, tok) = self.create_user()
@@ -1056,7 +1062,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(len(self.email_attempts), 0)
 
-    def create_user(self):
+    def create_user(self) -> Tuple[str, str]:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
         # We need to manually add an email address otherwise the handler will do
@@ -1073,7 +1079,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         )
         return user_id, tok
 
-    def test_manual_email_send_expired_account(self):
+    def test_manual_email_send_expired_account(self) -> None:
         user_id = self.register_user("kermit", "monkey")
         tok = self.login("kermit", "monkey")
 
@@ -1103,7 +1109,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/unstable/account_validity/send_mail",
             access_token=tok,
         )
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         self.assertEqual(len(self.email_attempts), 1)
 
@@ -1112,7 +1118,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
     servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.validity_period = 10
         self.max_delta = self.validity_period * 10.0 / 100.0
 
@@ -1126,14 +1132,16 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
         # We need to set these directly, instead of in the homeserver config dict above.
         # This is due to account validity-related config options not being read by
         # Synapse when account_validity.enabled is False.
-        self.hs.get_datastore()._account_validity_period = self.validity_period
-        self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
+        self.hs.get_datastores().main._account_validity_period = self.validity_period
+        self.hs.get_datastores().main._account_validity_startup_job_max_delta = (
+            self.max_delta
+        )
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         return self.hs
 
-    def test_background_job(self):
+    def test_background_job(self) -> None:
         """
         Tests the same thing as test_background_job, except that it sets the
         startup_job_max_delta parameter and checks that the expiration date is within the
@@ -1156,14 +1164,14 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
     servlets = [register.register_servlets]
     url = "/_matrix/client/v1/register/m.login.registration_token/validity"
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["registration_requires_token"] = True
         return config
 
-    def test_GET_token_valid(self):
+    def test_GET_token_valid(self) -> None:
         token = "abcd"
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "registration_tokens",
@@ -1181,22 +1189,22 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
             b"GET",
             f"{self.url}?token={token}",
         )
-        self.assertEquals(channel.result["code"], b"200", channel.result)
-        self.assertEquals(channel.json_body["valid"], True)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["valid"], True)
 
-    def test_GET_token_invalid(self):
+    def test_GET_token_invalid(self) -> None:
         token = "1234"
         channel = self.make_request(
             b"GET",
             f"{self.url}?token={token}",
         )
-        self.assertEquals(channel.result["code"], b"200", channel.result)
-        self.assertEquals(channel.json_body["valid"], False)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["valid"], False)
 
     @override_config(
         {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
     )
-    def test_GET_ratelimiting(self):
+    def test_GET_ratelimiting(self) -> None:
         token = "1234"
 
         for i in range(0, 6):
@@ -1206,10 +1214,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
             )
 
             if i == 5:
-                self.assertEquals(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.result["code"], b"429", channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEquals(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.result["code"], b"200", channel.result)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
@@ -1217,4 +1225,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
             b"GET",
             f"{self.url}?token={token}",
         )
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index dfd9ffcb93..c8db45719e 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -18,11 +18,15 @@ import urllib.parse
 from typing import Dict, List, Optional, Tuple
 from unittest.mock import patch
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room, sync
+from synapse.server import HomeServer
 from synapse.storage.relations import RelationPaginationToken
 from synapse.types import JsonDict, StreamToken
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -52,8 +56,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
 
         self.user_id, self.user_token = self._create_user("alice")
         self.user2_id, self.user2_token = self._create_user("bob")
@@ -63,13 +67,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
         self.parent_id = res["event_id"]
 
-    def test_send_relation(self):
+    def test_send_relation(self) -> None:
         """Tests that sending a relation using the new /send_relation works
         creates the right shape of event.
         """
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         event_id = channel.json_body["event_id"]
 
@@ -78,7 +82,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "/rooms/%s/event/%s" % (self.room, event_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         self.assert_dict(
             {
@@ -95,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             channel.json_body,
         )
 
-    def test_deny_invalid_event(self):
+    def test_deny_invalid_event(self) -> None:
         """Test that we deny relations on non-existant events"""
         channel = self._send_relation(
             RelationTypes.ANNOTATION,
@@ -103,11 +107,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             parent_id="foo",
             content={"body": "foo", "msgtype": "m.text"},
         )
-        self.assertEquals(400, channel.code, channel.json_body)
+        self.assertEqual(400, channel.code, channel.json_body)
 
         # Unless that event is referenced from another event!
         self.get_success(
-            self.hs.get_datastore().db_pool.simple_insert(
+            self.hs.get_datastores().main.db_pool.simple_insert(
                 table="event_relations",
                 values={
                     "event_id": "bar",
@@ -123,9 +127,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             parent_id="foo",
             content={"body": "foo", "msgtype": "m.text"},
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
-    def test_deny_invalid_room(self):
+    def test_deny_invalid_room(self) -> None:
         """Test that we deny relations on non-existant events"""
         # Create another room and send a message in it.
         room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
@@ -136,17 +140,17 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
         )
-        self.assertEquals(400, channel.code, channel.json_body)
+        self.assertEqual(400, channel.code, channel.json_body)
 
-    def test_deny_double_react(self):
+    def test_deny_double_react(self) -> None:
         """Test that we deny relations on membership events"""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEquals(400, channel.code, channel.json_body)
+        self.assertEqual(400, channel.code, channel.json_body)
 
-    def test_deny_forked_thread(self):
+    def test_deny_forked_thread(self) -> None:
         """It is invalid to start a thread off a thread."""
         channel = self._send_relation(
             RelationTypes.THREAD,
@@ -154,7 +158,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             content={"msgtype": "m.text", "body": "foo"},
             parent_id=self.parent_id,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         parent_id = channel.json_body["event_id"]
 
         channel = self._send_relation(
@@ -163,16 +167,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             content={"msgtype": "m.text", "body": "foo"},
             parent_id=parent_id,
         )
-        self.assertEquals(400, channel.code, channel.json_body)
+        self.assertEqual(400, channel.code, channel.json_body)
 
-    def test_basic_paginate_relations(self):
+    def test_basic_paginate_relations(self) -> None:
         """Tests that calling pagination API correctly the latest relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         first_annotation_id = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         second_annotation_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -180,11 +184,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # We expect to get back a single pagination result, which is the latest
         # full relation event we sent above.
-        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
         self.assert_dict(
             {
                 "event_id": second_annotation_id,
@@ -195,7 +199,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
 
         # We also expect to get the original event (the id of which is self.parent_id)
-        self.assertEquals(
+        self.assertEqual(
             channel.json_body["original_event"]["event_id"], self.parent_id
         )
 
@@ -212,11 +216,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # We expect to get back a single pagination result, which is the earliest
         # full relation event we sent above.
-        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
         self.assert_dict(
             {
                 "event_id": first_annotation_id,
@@ -235,7 +239,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             ).to_string(self.store)
         )
 
-    def test_repeated_paginate_relations(self):
+    def test_repeated_paginate_relations(self) -> None:
         """Test that if we paginate using a limit and tokens then we get the
         expected events.
         """
@@ -245,7 +249,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             channel = self._send_relation(
                 RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
             )
-            self.assertEquals(200, channel.code, channel.json_body)
+            self.assertEqual(200, channel.code, channel.json_body)
             expected_event_ids.append(channel.json_body["event_id"])
 
         prev_token = ""
@@ -260,12 +264,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
                 access_token=self.user_token,
             )
-            self.assertEquals(200, channel.code, channel.json_body)
+            self.assertEqual(200, channel.code, channel.json_body)
 
             found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
             next_batch = channel.json_body.get("next_batch")
 
-            self.assertNotEquals(prev_token, next_batch)
+            self.assertNotEqual(prev_token, next_batch)
             prev_token = next_batch
 
             if not prev_token:
@@ -273,7 +277,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         # We paginated backwards, so reverse
         found_event_ids.reverse()
-        self.assertEquals(found_event_ids, expected_event_ids)
+        self.assertEqual(found_event_ids, expected_event_ids)
 
         # Reset and try again, but convert the tokens to the legacy format.
         prev_token = ""
@@ -288,12 +292,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
                 access_token=self.user_token,
             )
-            self.assertEquals(200, channel.code, channel.json_body)
+            self.assertEqual(200, channel.code, channel.json_body)
 
             found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
             next_batch = channel.json_body.get("next_batch")
 
-            self.assertNotEquals(prev_token, next_batch)
+            self.assertNotEqual(prev_token, next_batch)
             prev_token = next_batch
 
             if not prev_token:
@@ -301,12 +305,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         # We paginated backwards, so reverse
         found_event_ids.reverse()
-        self.assertEquals(found_event_ids, expected_event_ids)
+        self.assertEqual(found_event_ids, expected_event_ids)
 
-    def test_pagination_from_sync_and_messages(self):
+    def test_pagination_from_sync_and_messages(self) -> None:
         """Pagination tokens from /sync and /messages can be used to paginate /relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         annotation_id = channel.json_body["event_id"]
         # Send an event after the relation events.
         self.helper.send(self.room, body="Latest event", tok=self.user_token)
@@ -319,7 +323,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "GET", f"/sync?filter={filter}", access_token=self.user_token
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
         sync_prev_batch = room_timeline["prev_batch"]
         self.assertIsNotNone(sync_prev_batch)
@@ -335,7 +339,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/rooms/{self.room}/messages?dir=b&limit=1",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         messages_end = channel.json_body["end"]
         self.assertIsNotNone(messages_end)
         # Ensure the relation event is not in the chunk returned from /messages.
@@ -355,14 +359,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
                 access_token=self.user_token,
             )
-            self.assertEquals(200, channel.code, channel.json_body)
+            self.assertEqual(200, channel.code, channel.json_body)
 
             # The relation should be in the returned chunk.
             self.assertIn(
                 annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
             )
 
-    def test_aggregation_pagination_groups(self):
+    def test_aggregation_pagination_groups(self) -> None:
         """Test that we can paginate annotation groups correctly."""
 
         # We need to create ten separate users to send each reaction.
@@ -386,7 +390,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 key=key,
                 access_token=access_tokens[idx],
             )
-            self.assertEquals(200, channel.code, channel.json_body)
+            self.assertEqual(200, channel.code, channel.json_body)
 
             idx += 1
             idx %= len(access_tokens)
@@ -404,7 +408,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 % (self.room, self.parent_id, from_token),
                 access_token=self.user_token,
             )
-            self.assertEquals(200, channel.code, channel.json_body)
+            self.assertEqual(200, channel.code, channel.json_body)
 
             self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
 
@@ -419,15 +423,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
             next_batch = channel.json_body.get("next_batch")
 
-            self.assertNotEquals(prev_token, next_batch)
+            self.assertNotEqual(prev_token, next_batch)
             prev_token = next_batch
 
             if not prev_token:
                 break
 
-        self.assertEquals(sent_groups, found_groups)
+        self.assertEqual(sent_groups, found_groups)
 
-    def test_aggregation_pagination_within_group(self):
+    def test_aggregation_pagination_within_group(self) -> None:
         """Test that we can paginate within an annotation group."""
 
         # We need to create ten separate users to send each reaction.
@@ -449,14 +453,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 key="👍",
                 access_token=access_tokens[idx],
             )
-            self.assertEquals(200, channel.code, channel.json_body)
+            self.assertEqual(200, channel.code, channel.json_body)
             expected_event_ids.append(channel.json_body["event_id"])
 
             idx += 1
 
         # Also send a different type of reaction so that we test we don't see it
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         prev_token = ""
         found_event_ids: List[str] = []
@@ -473,7 +477,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 f"/m.reaction/{encoded_key}?limit=1{from_token}",
                 access_token=self.user_token,
             )
-            self.assertEquals(200, channel.code, channel.json_body)
+            self.assertEqual(200, channel.code, channel.json_body)
 
             self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
 
@@ -481,7 +485,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
             next_batch = channel.json_body.get("next_batch")
 
-            self.assertNotEquals(prev_token, next_batch)
+            self.assertNotEqual(prev_token, next_batch)
             prev_token = next_batch
 
             if not prev_token:
@@ -489,7 +493,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         # We paginated backwards, so reverse
         found_event_ids.reverse()
-        self.assertEquals(found_event_ids, expected_event_ids)
+        self.assertEqual(found_event_ids, expected_event_ids)
 
         # Reset and try again, but convert the tokens to the legacy format.
         prev_token = ""
@@ -506,7 +510,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 f"/m.reaction/{encoded_key}?limit=1{from_token}",
                 access_token=self.user_token,
             )
-            self.assertEquals(200, channel.code, channel.json_body)
+            self.assertEqual(200, channel.code, channel.json_body)
 
             self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
 
@@ -514,7 +518,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
             next_batch = channel.json_body.get("next_batch")
 
-            self.assertNotEquals(prev_token, next_batch)
+            self.assertNotEqual(prev_token, next_batch)
             prev_token = next_batch
 
             if not prev_token:
@@ -522,21 +526,21 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         # We paginated backwards, so reverse
         found_event_ids.reverse()
-        self.assertEquals(found_event_ids, expected_event_ids)
+        self.assertEqual(found_event_ids, expected_event_ids)
 
-    def test_aggregation(self):
+    def test_aggregation(self) -> None:
         """Test that annotations get correctly aggregated."""
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self.make_request(
             "GET",
@@ -544,9 +548,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             % (self.room, self.parent_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
-        self.assertEquals(
+        self.assertEqual(
             channel.json_body,
             {
                 "chunk": [
@@ -556,17 +560,17 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def test_aggregation_redactions(self):
+    def test_aggregation_redactions(self) -> None:
         """Test that annotations get correctly aggregated after a redaction."""
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         to_redact_event_id = channel.json_body["event_id"]
 
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # Now lets redact one of the 'a' reactions
         channel = self.make_request(
@@ -575,7 +579,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             access_token=self.user_token,
             content={},
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self.make_request(
             "GET",
@@ -583,14 +587,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             % (self.room, self.parent_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
-        self.assertEquals(
+        self.assertEqual(
             channel.json_body,
             {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
         )
 
-    def test_aggregation_must_be_annotation(self):
+    def test_aggregation_must_be_annotation(self) -> None:
         """Test that aggregations must be annotations."""
 
         channel = self.make_request(
@@ -599,12 +603,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             % (self.room, self.parent_id, RelationTypes.REPLACE),
             access_token=self.user_token,
         )
-        self.assertEquals(400, channel.code, channel.json_body)
+        self.assertEqual(400, channel.code, channel.json_body)
 
     @unittest.override_config(
         {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
     )
-    def test_bundled_aggregations(self):
+    def test_bundled_aggregations(self) -> None:
         """
         Test that annotations, references, and threads get correctly bundled.
 
@@ -615,29 +619,29 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         """
         # Setup by sending a variety of relations.
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         reply_1 = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         reply_2 = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         thread_2 = channel.json_body["event_id"]
 
         def assert_bundle(event_json: JsonDict) -> None:
@@ -655,7 +659,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             )
 
             # Check the values of each field.
-            self.assertEquals(
+            self.assertEqual(
                 {
                     "chunk": [
                         {"type": "m.reaction", "key": "a", "count": 2},
@@ -665,12 +669,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 relations_dict[RelationTypes.ANNOTATION],
             )
 
-            self.assertEquals(
+            self.assertEqual(
                 {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
                 relations_dict[RelationTypes.REFERENCE],
             )
 
-            self.assertEquals(
+            self.assertEqual(
                 2,
                 relations_dict[RelationTypes.THREAD].get("count"),
             )
@@ -701,7 +705,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/rooms/{self.room}/event/{self.parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         assert_bundle(channel.json_body)
 
         # Request the room messages.
@@ -710,7 +714,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/rooms/{self.room}/messages?dir=b",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
 
         # Request the room context.
@@ -719,12 +723,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/rooms/{self.room}/context/{self.parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         assert_bundle(channel.json_body["event"])
 
         # Request sync.
         channel = self.make_request("GET", "/sync", access_token=self.user_token)
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
         self.assertTrue(room_timeline["limited"])
         assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
@@ -737,7 +741,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             content={"search_categories": {"room_events": {"search_term": "Hi"}}},
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         chunk = [
             result["result"]
             for result in channel.json_body["search_categories"]["room_events"][
@@ -746,47 +750,47 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         ]
         assert_bundle(self._find_event_in_chunk(chunk))
 
-    def test_aggregation_get_event_for_annotation(self):
+    def test_aggregation_get_event_for_annotation(self) -> None:
         """Test that annotations do not get bundled aggregations included
         when directly requested.
         """
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         annotation_id = channel.json_body["event_id"]
 
         # Annotate the annotation.
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self.make_request(
             "GET",
             f"/rooms/{self.room}/event/{annotation_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
 
-    def test_aggregation_get_event_for_thread(self):
+    def test_aggregation_get_event_for_thread(self) -> None:
         """Test that threads get bundled aggregations included when directly requested."""
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         thread_id = channel.json_body["event_id"]
 
         # Annotate the annotation.
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self.make_request(
             "GET",
             f"/rooms/{self.room}/event/{thread_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
-        self.assertEquals(
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(
             channel.json_body["unsigned"].get("m.relations"),
             {
                 RelationTypes.ANNOTATION: {
@@ -801,11 +805,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         self.assertEqual(len(channel.json_body["chunk"]), 1)
 
         thread_message = channel.json_body["chunk"][0]
-        self.assertEquals(
+        self.assertEqual(
             thread_message["unsigned"].get("m.relations"),
             {
                 RelationTypes.ANNOTATION: {
@@ -815,7 +819,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
 
     @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
-    def test_ignore_invalid_room(self):
+    def test_ignore_invalid_room(self) -> None:
         """Test that we ignore invalid relations over federation."""
         # Create another room and send a message in it.
         room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
@@ -905,7 +909,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         self.assertEqual(channel.json_body["chunk"], [])
 
         # And when fetching aggregations.
@@ -914,7 +918,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         self.assertEqual(channel.json_body["chunk"], [])
 
         # And for bundled aggregations.
@@ -923,11 +927,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/rooms/{room2}/event/{parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         self.assertNotIn("m.relations", channel.json_body["unsigned"])
 
     @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
-    def test_edit(self):
+    def test_edit(self) -> None:
         """Test that a simple edit works."""
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -936,7 +940,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         edit_event_id = channel.json_body["event_id"]
 
@@ -958,8 +962,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/rooms/{self.room}/event/{self.parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
-        self.assertEquals(channel.json_body["content"], new_body)
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(channel.json_body["content"], new_body)
         assert_bundle(channel.json_body)
 
         # Request the room messages.
@@ -968,7 +972,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/rooms/{self.room}/messages?dir=b",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
 
         # Request the room context.
@@ -977,7 +981,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/rooms/{self.room}/context/{self.parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         assert_bundle(channel.json_body["event"])
 
         # Request sync, but limit the timeline so it becomes limited (and includes
@@ -988,7 +992,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "GET", f"/sync?filter={filter}", access_token=self.user_token
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
         self.assertTrue(room_timeline["limited"])
         assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
@@ -1001,7 +1005,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             content={"search_categories": {"room_events": {"search_term": "Hi"}}},
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         chunk = [
             result["result"]
             for result in channel.json_body["search_categories"]["room_events"][
@@ -1010,7 +1014,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         ]
         assert_bundle(self._find_event_in_chunk(chunk))
 
-    def test_multi_edit(self):
+    def test_multi_edit(self) -> None:
         """Test that multiple edits, including attempts by people who
         shouldn't be allowed, are correctly handled.
         """
@@ -1024,7 +1028,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "First edit"},
             },
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
         channel = self._send_relation(
@@ -1032,7 +1036,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         edit_event_id = channel.json_body["event_id"]
 
@@ -1045,16 +1049,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
             },
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
-        self.assertEquals(channel.json_body["content"], new_body)
+        self.assertEqual(channel.json_body["content"], new_body)
 
         relations_dict = channel.json_body["unsigned"].get("m.relations")
         self.assertIn(RelationTypes.REPLACE, relations_dict)
@@ -1067,7 +1071,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
-    def test_edit_reply(self):
+    def test_edit_reply(self) -> None:
         """Test that editing a reply works."""
 
         # Create a reply to edit.
@@ -1076,7 +1080,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "A reply!"},
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         reply = channel.json_body["event_id"]
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1086,7 +1090,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
             parent_id=reply,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         edit_event_id = channel.json_body["event_id"]
 
@@ -1095,7 +1099,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "/rooms/%s/event/%s" % (self.room, reply),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # We expect to see the new body in the dict, as well as the reference
         # metadata sill intact.
@@ -1124,7 +1128,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
 
     @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
-    def test_edit_thread(self):
+    def test_edit_thread(self) -> None:
         """Test that editing a thread works."""
 
         # Create a thread and edit the last event.
@@ -1133,7 +1137,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "A threaded reply!"},
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         threaded_event_id = channel.json_body["event_id"]
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1143,7 +1147,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
             parent_id=threaded_event_id,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # Fetch the thread root, to get the bundled aggregation for the thread.
         channel = self.make_request(
@@ -1151,7 +1155,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/rooms/{self.room}/event/{self.parent_id}",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # We expect that the edit message appears in the thread summary in the
         # unsigned relations section.
@@ -1161,11 +1165,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         thread_summary = relations_dict[RelationTypes.THREAD]
         self.assertIn("latest_event", thread_summary)
         latest_event_in_thread = thread_summary["latest_event"]
-        self.assertEquals(
-            latest_event_in_thread["content"]["body"], "I've been edited!"
-        )
+        self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
 
-    def test_edit_edit(self):
+    def test_edit_edit(self) -> None:
         """Test that an edit cannot be edited."""
         new_body = {"msgtype": "m.text", "body": "Initial edit"}
         channel = self._send_relation(
@@ -1177,7 +1179,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 "m.new_content": new_body,
             },
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         edit_event_id = channel.json_body["event_id"]
 
         # Edit the edit event.
@@ -1191,7 +1193,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             },
             parent_id=edit_event_id,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # Request the original event.
         channel = self.make_request(
@@ -1199,9 +1201,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         # The edit to the edit should be ignored.
-        self.assertEquals(channel.json_body["content"], new_body)
+        self.assertEqual(channel.json_body["content"], new_body)
 
         # The relations information should not include the edit to the edit.
         relations_dict = channel.json_body["unsigned"].get("m.relations")
@@ -1215,7 +1217,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
-    def test_relations_redaction_redacts_edits(self):
+    def test_relations_redaction_redacts_edits(self) -> None:
         """Test that edits of an event are redacted when the original event
         is redacted.
         """
@@ -1234,7 +1236,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "First edit"},
             },
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # Check the relation is returned
         channel = self.make_request(
@@ -1243,10 +1245,10 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             % (self.room, original_event_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         self.assertIn("chunk", channel.json_body)
-        self.assertEquals(len(channel.json_body["chunk"]), 1)
+        self.assertEqual(len(channel.json_body["chunk"]), 1)
 
         # Redact the original event
         channel = self.make_request(
@@ -1256,7 +1258,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             access_token=self.user_token,
             content="{}",
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # Try to check for remaining m.replace relations
         channel = self.make_request(
@@ -1265,13 +1267,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             % (self.room, original_event_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # Check that no relations are returned
         self.assertIn("chunk", channel.json_body)
-        self.assertEquals(channel.json_body["chunk"], [])
+        self.assertEqual(channel.json_body["chunk"], [])
 
-    def test_aggregations_redaction_prevents_access_to_aggregations(self):
+    def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None:
         """Test that annotations of an event are redacted when the original event
         is redacted.
         """
@@ -1283,7 +1285,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # Redact the original
         channel = self.make_request(
@@ -1297,7 +1299,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             access_token=self.user_token,
             content="{}",
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # Check that aggregations returns zero
         channel = self.make_request(
@@ -1306,15 +1308,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             % (self.room, original_event_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         self.assertIn("chunk", channel.json_body)
-        self.assertEquals(channel.json_body["chunk"], [])
+        self.assertEqual(channel.json_body["chunk"], [])
 
-    def test_unknown_relations(self):
+    def test_unknown_relations(self) -> None:
         """Unknown relations should be accepted."""
         channel = self._send_relation("m.relation.test", "m.room.test")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -1323,18 +1325,18 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             % (self.room, self.parent_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
 
         # We expect to get back a single pagination result, which is the full
         # relation event we sent above.
-        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
         self.assert_dict(
             {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
             channel.json_body["chunk"][0],
         )
 
         # We also expect to get the original event (the id of which is self.parent_id)
-        self.assertEquals(
+        self.assertEqual(
             channel.json_body["original_event"]["event_id"], self.parent_id
         )
 
@@ -1344,7 +1346,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         self.assertNotIn("m.relations", channel.json_body["unsigned"])
 
         # But unknown relations can be directly queried.
@@ -1354,8 +1356,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             % (self.room, self.parent_id),
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
-        self.assertEquals(channel.json_body["chunk"], [])
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(channel.json_body["chunk"], [])
 
     def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
         """
@@ -1419,18 +1421,18 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         return user_id, access_token
 
-    def test_background_update(self):
+    def test_background_update(self) -> None:
         """Test the event_arbitrary_relations background update."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         annotation_event_id_good = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         annotation_event_id_bad = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         thread_event_id = channel.json_body["event_id"]
 
         # Clean-up the table as if the inserts did not happen during event creation.
@@ -1450,8 +1452,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
-        self.assertEquals(
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(
             [ev["event_id"] for ev in channel.json_body["chunk"]],
             [annotation_event_id_good],
         )
@@ -1475,7 +1477,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
             access_token=self.user_token,
         )
-        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(200, channel.code, channel.json_body)
         self.assertCountEqual(
             [ev["event_id"] for ev in channel.json_body["chunk"]],
             [annotation_event_id_good, thread_event_id],
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index fe5b536d97..f3bf8d0934 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -13,9 +13,14 @@
 # limitations under the License.
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 from synapse.visibility import filter_events_for_client
 
 from tests import unittest
@@ -31,7 +36,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["retention"] = {
             "enabled": True,
@@ -47,15 +52,15 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         return self.hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("user", "password")
         self.token = self.login("user", "password")
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.serializer = self.hs.get_event_client_serializer()
         self.clock = self.hs.get_clock()
 
-    def test_retention_event_purged_with_state_event(self):
+    def test_retention_event_purged_with_state_event(self) -> None:
         """Tests that expired events are correctly purged when the room's retention policy
         is defined by a state event.
         """
@@ -72,7 +77,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         self._test_retention_event_purged(room_id, one_day_ms * 1.5)
 
-    def test_retention_event_purged_with_state_event_outside_allowed(self):
+    def test_retention_event_purged_with_state_event_outside_allowed(self) -> None:
         """Tests that the server configuration can override the policy for a room when
         running the purge jobs.
         """
@@ -102,7 +107,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # instead of the one specified in the room's policy.
         self._test_retention_event_purged(room_id, one_day_ms * 0.5)
 
-    def test_retention_event_purged_without_state_event(self):
+    def test_retention_event_purged_without_state_event(self) -> None:
         """Tests that expired events are correctly purged when the room's retention policy
         is defined by the server's configuration's default retention policy.
         """
@@ -110,11 +115,11 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         self._test_retention_event_purged(room_id, one_day_ms * 2)
 
-    def test_visibility(self):
+    def test_visibility(self) -> None:
         """Tests that synapse.visibility.filter_events_for_client correctly filters out
         outdated events
         """
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         storage = self.hs.get_storage()
         room_id = self.helper.create_room_as(self.user_id, tok=self.token)
         events = []
@@ -152,7 +157,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # That event should be the second, not outdated event.
         self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
 
-    def _test_retention_event_purged(self, room_id: str, increment: float):
+    def _test_retention_event_purged(self, room_id: str, increment: float) -> None:
         """Run the following test scenario to test the message retention policy support:
 
         1. Send event 1
@@ -186,6 +191,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
 
         expired_event_id = resp.get("event_id")
+        assert expired_event_id is not None
 
         # Check that we can retrieve the event.
         expired_event = self.get_event(expired_event_id)
@@ -201,6 +207,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
 
         valid_event_id = resp.get("event_id")
+        assert valid_event_id is not None
 
         # Advance the time again. Now our first event should have expired but our second
         # one should still be kept.
@@ -218,7 +225,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # has been purged.
         self.get_event(room_id, create_event.event_id)
 
-    def get_event(self, event_id, expect_none=False):
+    def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
         event = self.get_success(self.store.get_event(event_id, allow_none=True))
 
         if expect_none:
@@ -240,7 +247,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["retention"] = {
             "enabled": True,
@@ -254,11 +261,11 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
         )
         return self.hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("user", "password")
         self.token = self.login("user", "password")
 
-    def test_no_default_policy(self):
+    def test_no_default_policy(self) -> None:
         """Tests that an event doesn't get expired if there is neither a default retention
         policy nor a policy specific to the room.
         """
@@ -266,7 +273,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
 
         self._test_retention(room_id)
 
-    def test_state_policy(self):
+    def test_state_policy(self) -> None:
         """Tests that an event gets correctly expired if there is no default retention
         policy but there's a policy specific to the room.
         """
@@ -283,12 +290,15 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
 
         self._test_retention(room_id, expected_code_for_first_event=404)
 
-    def _test_retention(self, room_id, expected_code_for_first_event=200):
+    def _test_retention(
+        self, room_id: str, expected_code_for_first_event: int = 200
+    ) -> None:
         # Send a first event to the room. This is the event we'll want to be purged at the
         # end of the test.
         resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
 
         first_event_id = resp.get("event_id")
+        assert first_event_id is not None
 
         # Check that we can retrieve the event.
         expired_event = self.get_event(room_id, first_event_id)
@@ -304,6 +314,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
         resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
 
         second_event_id = resp.get("event_id")
+        assert second_event_id is not None
 
         # Advance the time by another month.
         self.reactor.advance(one_day_ms * 30 / 1000)
@@ -322,7 +333,9 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
         second_event = self.get_event(room_id, second_event_id)
         self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event)
 
-    def get_event(self, room_id, event_id, expected_code=200):
+    def get_event(
+        self, room_id: str, event_id: str, expected_code: int = 200
+    ) -> JsonDict:
         url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
 
         channel = self.make_request("GET", url, access_token=self.token)
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
index e9f8704035..44f333a0ee 100644
--- a/tests/rest/client/test_room_batch.py
+++ b/tests/rest/client/test_room_batch.py
@@ -134,7 +134,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
         return room_id, event_id_a, event_id_b, event_id_c
 
     @unittest.override_config({"experimental_features": {"msc2716_enabled": True}})
-    def test_same_state_groups_for_whole_historical_batch(self):
+    def test_same_state_groups_for_whole_historical_batch(self) -> None:
         """Make sure that when using the `/batch_send` endpoint to import a
         bunch of historical messages, it re-uses the same `state_group` across
         the whole batch. This is an easy optimization to make sure we're getting
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index b7f086927b..e0b11e7264 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -65,7 +65,7 @@ class RoomBase(unittest.HomeserverTestCase):
         async def _insert_client_ip(*args, **kwargs):
             return None
 
-        self.hs.get_datastore().insert_client_ip = _insert_client_ip
+        self.hs.get_datastores().main.insert_client_ip = _insert_client_ip
 
         return self.hs
 
@@ -95,7 +95,7 @@ class RoomPermissionsTestCase(RoomBase):
         channel = self.make_request(
             "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}'
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         # set topic for public room
         channel = self.make_request(
@@ -103,7 +103,7 @@ class RoomPermissionsTestCase(RoomBase):
             ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"),
             b'{"topic":"Public Room Topic"}',
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         # auth as user_id now
         self.helper.auth_user_id = self.user_id
@@ -125,28 +125,28 @@ class RoomPermissionsTestCase(RoomBase):
             "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
             msg_content,
         )
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
         # send message in created room not joined (no state), expect 403
         channel = self.make_request("PUT", send_msg_path(), msg_content)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
         # send message in created room and invited, expect 403
         self.helper.invite(
             room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
         )
         channel = self.make_request("PUT", send_msg_path(), msg_content)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
         # send message in created room and joined, expect 200
         self.helper.join(room=self.created_rmid, user=self.user_id)
         channel = self.make_request("PUT", send_msg_path(), msg_content)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         # send message in created room and left, expect 403
         self.helper.leave(room=self.created_rmid, user=self.user_id)
         channel = self.make_request("PUT", send_msg_path(), msg_content)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
     def test_topic_perms(self):
         topic_content = b'{"topic":"My Topic Name"}'
@@ -156,28 +156,28 @@ class RoomPermissionsTestCase(RoomBase):
         channel = self.make_request(
             "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content
         )
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
         channel = self.make_request(
             "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid
         )
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
         # set/get topic in created PRIVATE room not joined, expect 403
         channel = self.make_request("PUT", topic_path, topic_content)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
         channel = self.make_request("GET", topic_path)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
         # set topic in created PRIVATE room and invited, expect 403
         self.helper.invite(
             room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id
         )
         channel = self.make_request("PUT", topic_path, topic_content)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
         # get topic in created PRIVATE room and invited, expect 403
         channel = self.make_request("GET", topic_path)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
         # set/get topic in created PRIVATE room and joined, expect 200
         self.helper.join(room=self.created_rmid, user=self.user_id)
@@ -185,25 +185,25 @@ class RoomPermissionsTestCase(RoomBase):
         # Only room ops can set topic by default
         self.helper.auth_user_id = self.rmcreator_id
         channel = self.make_request("PUT", topic_path, topic_content)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.helper.auth_user_id = self.user_id
 
         channel = self.make_request("GET", topic_path)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body)
 
         # set/get topic in created PRIVATE room and left, expect 403
         self.helper.leave(room=self.created_rmid, user=self.user_id)
         channel = self.make_request("PUT", topic_path, topic_content)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
         channel = self.make_request("GET", topic_path)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         # get topic in PUBLIC room, not joined, expect 403
         channel = self.make_request(
             "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid
         )
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
         # set topic in PUBLIC room, not joined, expect 403
         channel = self.make_request(
@@ -211,7 +211,7 @@ class RoomPermissionsTestCase(RoomBase):
             "/rooms/%s/state/m.room.topic" % self.created_public_rmid,
             topic_content,
         )
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
     def _test_get_membership(
         self, room=None, members: Iterable = frozenset(), expect_code=None
@@ -219,7 +219,7 @@ class RoomPermissionsTestCase(RoomBase):
         for member in members:
             path = "/rooms/%s/state/m.room.member/%s" % (room, member)
             channel = self.make_request("GET", path)
-            self.assertEquals(expect_code, channel.code)
+            self.assertEqual(expect_code, channel.code)
 
     def test_membership_basic_room_perms(self):
         # === room does not exist ===
@@ -478,16 +478,16 @@ class RoomsMemberListTestCase(RoomBase):
     def test_get_member_list(self):
         room_id = self.helper.create_room_as(self.user_id)
         channel = self.make_request("GET", "/rooms/%s/members" % room_id)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_no_room(self):
         channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_no_permission(self):
         room_id = self.helper.create_room_as("@some_other_guy:red")
         channel = self.make_request("GET", "/rooms/%s/members" % room_id)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_no_permission_with_at_token(self):
         """
@@ -498,7 +498,7 @@ class RoomsMemberListTestCase(RoomBase):
 
         # first sync to get an at token
         channel = self.make_request("GET", "/sync")
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         sync_token = channel.json_body["next_batch"]
 
         # check that permission is denied for @sid1:red to get the
@@ -507,7 +507,7 @@ class RoomsMemberListTestCase(RoomBase):
             "GET",
             f"/rooms/{room_id}/members?at={sync_token}",
         )
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_no_permission_former_member(self):
         """
@@ -520,14 +520,14 @@ class RoomsMemberListTestCase(RoomBase):
 
         # check that the user can see the member list to start with
         channel = self.make_request("GET", "/rooms/%s/members" % room_id)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         # ban the user
         self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban")
 
         # check the user can no longer see the member list
         channel = self.make_request("GET", "/rooms/%s/members" % room_id)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_no_permission_former_member_with_at_token(self):
         """
@@ -541,14 +541,14 @@ class RoomsMemberListTestCase(RoomBase):
 
         # sync to get an at token
         channel = self.make_request("GET", "/sync")
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         sync_token = channel.json_body["next_batch"]
 
         # check that the user can see the member list to start with
         channel = self.make_request(
             "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
         )
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         # ban the user (Note: the user is actually allowed to see this event and
         # state so that they know they're banned!)
@@ -560,14 +560,14 @@ class RoomsMemberListTestCase(RoomBase):
 
         # now, with the original user, sync again to get a new at token
         channel = self.make_request("GET", "/sync")
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         sync_token = channel.json_body["next_batch"]
 
         # check the user can no longer see the updated member list
         channel = self.make_request(
             "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token)
         )
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
     def test_get_member_list_mixed_memberships(self):
         room_creator = "@some_other_guy:red"
@@ -576,17 +576,17 @@ class RoomsMemberListTestCase(RoomBase):
         self.helper.invite(room=room_id, src=room_creator, targ=self.user_id)
         # can't see list if you're just invited.
         channel = self.make_request("GET", room_path)
-        self.assertEquals(403, channel.code, msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.result["body"])
 
         self.helper.join(room=room_id, user=self.user_id)
         # can see list now joined
         channel = self.make_request("GET", room_path)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         self.helper.leave(room=room_id, user=self.user_id)
         # can see old list once left
         channel = self.make_request("GET", room_path)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
 
 class RoomsCreateTestCase(RoomBase):
@@ -598,19 +598,19 @@ class RoomsCreateTestCase(RoomBase):
         # POST with no config keys, expect new room id
         channel = self.make_request("POST", "/createRoom", "{}")
 
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_visibility_key(self):
         # POST with visibility config key, expect new room id
         channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_custom_key(self):
         # POST with custom config keys, expect new room id
         channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_known_and_unknown_keys(self):
@@ -618,16 +618,16 @@ class RoomsCreateTestCase(RoomBase):
         channel = self.make_request(
             "POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         self.assertTrue("room_id" in channel.json_body)
 
     def test_post_room_invalid_content(self):
         # POST with invalid content / paths, expect 400
         channel = self.make_request("POST", "/createRoom", b'{"visibili')
-        self.assertEquals(400, channel.code)
+        self.assertEqual(400, channel.code)
 
         channel = self.make_request("POST", "/createRoom", b'["hello"]')
-        self.assertEquals(400, channel.code)
+        self.assertEqual(400, channel.code)
 
     def test_post_room_invitees_invalid_mxid(self):
         # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
@@ -635,7 +635,7 @@ class RoomsCreateTestCase(RoomBase):
         channel = self.make_request(
             "POST", "/createRoom", b'{"invite":["@alice:example.com "]}'
         )
-        self.assertEquals(400, channel.code)
+        self.assertEqual(400, channel.code)
 
     @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
     def test_post_room_invitees_ratelimit(self):
@@ -667,7 +667,7 @@ class RoomsCreateTestCase(RoomBase):
 
         # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
         self.get_success(
-            self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0)
+            self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0)
         )
 
         # Test that the invites aren't ratelimited anymore.
@@ -694,9 +694,9 @@ class RoomsCreateTestCase(RoomBase):
             "/createRoom",
             {},
         )
-        self.assertEquals(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.code, 200, channel.json_body)
 
-        self.assertEquals(join_mock.call_count, 0)
+        self.assertEqual(join_mock.call_count, 0)
 
 
 class RoomTopicTestCase(RoomBase):
@@ -712,54 +712,54 @@ class RoomTopicTestCase(RoomBase):
     def test_invalid_puts(self):
         # missing keys or invalid json
         channel = self.make_request("PUT", self.path, "{}")
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", self.path, '{"_name":"bo"}')
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", self.path, '{"nao')
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request(
             "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]'
         )
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", self.path, "text only")
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", self.path, "")
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         # valid key, wrong type
         content = '{"topic":["Topic name"]}'
         channel = self.make_request("PUT", self.path, content)
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
     def test_rooms_topic(self):
         # nothing should be there
         channel = self.make_request("GET", self.path)
-        self.assertEquals(404, channel.code, msg=channel.result["body"])
+        self.assertEqual(404, channel.code, msg=channel.result["body"])
 
         # valid put
         content = '{"topic":"Topic name"}'
         channel = self.make_request("PUT", self.path, content)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         # valid get
         channel = self.make_request("GET", self.path)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assert_dict(json.loads(content), channel.json_body)
 
     def test_rooms_topic_with_extra_keys(self):
         # valid put with extra keys
         content = '{"topic":"Seasons","subtopic":"Summer"}'
         channel = self.make_request("PUT", self.path, content)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         # valid get
         channel = self.make_request("GET", self.path)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assert_dict(json.loads(content), channel.json_body)
 
 
@@ -775,22 +775,22 @@ class RoomMemberStateTestCase(RoomBase):
         path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
         # missing keys or invalid json
         channel = self.make_request("PUT", path, "{}")
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, '{"_name":"bo"}')
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, '{"nao')
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, "text only")
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, "")
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         # valid keys, wrong types
         content = '{"membership":["%s","%s","%s"]}' % (
@@ -799,7 +799,7 @@ class RoomMemberStateTestCase(RoomBase):
             Membership.LEAVE,
         )
         channel = self.make_request("PUT", path, content.encode("ascii"))
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
     def test_rooms_members_self(self):
         path = "/rooms/%s/state/m.room.member/%s" % (
@@ -810,13 +810,13 @@ class RoomMemberStateTestCase(RoomBase):
         # valid join message (NOOP since we made the room)
         content = '{"membership":"%s"}' % Membership.JOIN
         channel = self.make_request("PUT", path, content.encode("ascii"))
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("GET", path, None)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         expected_response = {"membership": Membership.JOIN}
-        self.assertEquals(expected_response, channel.json_body)
+        self.assertEqual(expected_response, channel.json_body)
 
     def test_rooms_members_other(self):
         self.other_id = "@zzsid1:red"
@@ -828,11 +828,11 @@ class RoomMemberStateTestCase(RoomBase):
         # valid invite message
         content = '{"membership":"%s"}' % Membership.INVITE
         channel = self.make_request("PUT", path, content)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("GET", path, None)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
-        self.assertEquals(json.loads(content), channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(json.loads(content), channel.json_body)
 
     def test_rooms_members_other_custom_keys(self):
         self.other_id = "@zzsid1:red"
@@ -847,11 +847,11 @@ class RoomMemberStateTestCase(RoomBase):
             "Join us!",
         )
         channel = self.make_request("PUT", path, content)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("GET", path, None)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
-        self.assertEquals(json.loads(content), channel.json_body)
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(json.loads(content), channel.json_body)
 
 
 class RoomInviteRatelimitTestCase(RoomBase):
@@ -937,7 +937,7 @@ class RoomJoinTestCase(RoomBase):
                 False,
             ),
         )
-        self.assertEquals(
+        self.assertEqual(
             callback_mock.call_args,
             expected_call_args,
             callback_mock.call_args,
@@ -955,7 +955,7 @@ class RoomJoinTestCase(RoomBase):
                 True,
             ),
         )
-        self.assertEquals(
+        self.assertEqual(
             callback_mock.call_args,
             expected_call_args,
             callback_mock.call_args,
@@ -1013,7 +1013,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
         # Update the display name for the user.
         path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
         channel = self.make_request("PUT", path, {"displayname": "John Doe"})
-        self.assertEquals(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.code, 200, channel.json_body)
 
         # Check that all the rooms have been sent a profile update into.
         for room_id in room_ids:
@@ -1023,10 +1023,10 @@ class RoomJoinRatelimitTestCase(RoomBase):
             )
 
             channel = self.make_request("GET", path)
-            self.assertEquals(channel.code, 200)
+            self.assertEqual(channel.code, 200)
 
             self.assertIn("displayname", channel.json_body)
-            self.assertEquals(channel.json_body["displayname"], "John Doe")
+            self.assertEqual(channel.json_body["displayname"], "John Doe")
 
     @unittest.override_config(
         {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
@@ -1047,7 +1047,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
             # if all of these requests ended up joining the user to a room.
             for _ in range(4):
                 channel = self.make_request("POST", path % room_id, {})
-                self.assertEquals(channel.code, 200)
+                self.assertEqual(channel.code, 200)
 
     @unittest.override_config(
         {
@@ -1060,7 +1060,9 @@ class RoomJoinRatelimitTestCase(RoomBase):
         user_id = self.register_user("testuser", "password")
 
         # Check that the new user successfully joined the four rooms
-        rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
+        rooms = self.get_success(
+            self.hs.get_datastores().main.get_rooms_for_user(user_id)
+        )
         self.assertEqual(len(rooms), 4)
 
 
@@ -1076,40 +1078,40 @@ class RoomMessagesTestCase(RoomBase):
         path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
         # missing keys or invalid json
         channel = self.make_request("PUT", path, b"{}")
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, b'{"_name":"bo"}')
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, b'{"nao')
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]')
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, b"text only")
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         channel = self.make_request("PUT", path, b"")
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
     def test_rooms_messages_sent(self):
         path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
 
         content = b'{"body":"test","msgtype":{"type":"a"}}'
         channel = self.make_request("PUT", path, content)
-        self.assertEquals(400, channel.code, msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
 
         # custom message types
         content = b'{"body":"test","msgtype":"test.custom.text"}'
         channel = self.make_request("PUT", path, content)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         # m.text message type
         path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id))
         content = b'{"body":"test2","msgtype":"m.text"}'
         channel = self.make_request("PUT", path, content)
-        self.assertEquals(200, channel.code, msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
 
 class RoomInitialSyncTestCase(RoomBase):
@@ -1123,10 +1125,10 @@ class RoomInitialSyncTestCase(RoomBase):
 
     def test_initial_sync(self):
         channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
-        self.assertEquals(self.room_id, channel.json_body["room_id"])
-        self.assertEquals("join", channel.json_body["membership"])
+        self.assertEqual(self.room_id, channel.json_body["room_id"])
+        self.assertEqual("join", channel.json_body["membership"])
 
         # Room state is easier to assert on if we unpack it into a dict
         state = {}
@@ -1150,7 +1152,7 @@ class RoomInitialSyncTestCase(RoomBase):
             e["content"]["user_id"]: e for e in channel.json_body["presence"]
         }
         self.assertTrue(self.user_id in presence_by_user)
-        self.assertEquals("m.presence", presence_by_user[self.user_id]["type"])
+        self.assertEqual("m.presence", presence_by_user[self.user_id]["type"])
 
 
 class RoomMessageListTestCase(RoomBase):
@@ -1166,9 +1168,9 @@ class RoomMessageListTestCase(RoomBase):
         channel = self.make_request(
             "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         self.assertTrue("start" in channel.json_body)
-        self.assertEquals(token, channel.json_body["start"])
+        self.assertEqual(token, channel.json_body["start"])
         self.assertTrue("chunk" in channel.json_body)
         self.assertTrue("end" in channel.json_body)
 
@@ -1177,14 +1179,14 @@ class RoomMessageListTestCase(RoomBase):
         channel = self.make_request(
             "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         self.assertTrue("start" in channel.json_body)
-        self.assertEquals(token, channel.json_body["start"])
+        self.assertEqual(token, channel.json_body["start"])
         self.assertTrue("chunk" in channel.json_body)
         self.assertTrue("end" in channel.json_body)
 
     def test_room_messages_purge(self):
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         pagination_handler = self.hs.get_pagination_handler()
 
         # Send a first message in the room, which will be removed by the purge.
@@ -2612,7 +2614,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
             },
             access_token=self.tok,
         )
-        self.assertEquals(channel.code, 200)
+        self.assertEqual(channel.code, 200)
 
         # Check that the callback was called with the right params.
         mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
@@ -2634,7 +2636,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
             },
             access_token=self.tok,
         )
-        self.assertEquals(channel.code, 403)
+        self.assertEqual(channel.code, 403)
 
         # Also check that it stopped before calling _make_and_store_3pid_invite.
         make_invite_mock.assert_called_once()
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index e2ed14457f..c3942889e1 100644
--- a/tests/rest/client/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
@@ -26,7 +26,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def test_user_to_user(self):
+    def test_user_to_user(self) -> None:
         """A to-device message from one user to another should get delivered"""
 
         user1 = self.register_user("u1", "pass")
@@ -73,7 +73,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
         self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
 
     @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
-    def test_local_room_key_request(self):
+    def test_local_room_key_request(self) -> None:
         """m.room_key_request has special-casing; test from local user"""
         user1 = self.register_user("u1", "pass")
         user1_tok = self.login("u1", "pass", "d1")
@@ -128,7 +128,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
         )
 
     @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}})
-    def test_remote_room_key_request(self):
+    def test_remote_room_key_request(self) -> None:
         """m.room_key_request has special-casing; test from remote user"""
         user2 = self.register_user("u2", "pass")
         user2_tok = self.login("u2", "pass", "d2")
@@ -199,7 +199,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
             },
         )
 
-    def test_limited_sync(self):
+    def test_limited_sync(self) -> None:
         """If a limited sync for to-devices happens the next /sync should respond immediately."""
 
         self.register_user("u1", "pass")
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index b0c44af033..ae5ada3be7 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -14,6 +14,8 @@
 
 from unittest.mock import Mock, patch
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.constants import EventTypes
 from synapse.rest.client import (
@@ -23,18 +25,20 @@ from synapse.rest.client import (
     room,
     room_upgrade_rest_servlet,
 )
+from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util import Clock
 
 from tests import unittest
 
 
 class _ShadowBannedBase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Create two users, one of which is shadow-banned.
         self.banned_user_id = self.register_user("banned", "test")
         self.banned_access_token = self.login("banned", "test")
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         self.get_success(
             self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
@@ -55,7 +59,7 @@ class RoomTestCase(_ShadowBannedBase):
         room_upgrade_rest_servlet.register_servlets,
     ]
 
-    def test_invite(self):
+    def test_invite(self) -> None:
         """Invites from shadow-banned users don't actually get sent."""
 
         # The create works fine.
@@ -77,7 +81,7 @@ class RoomTestCase(_ShadowBannedBase):
         )
         self.assertEqual(invited_rooms, [])
 
-    def test_invite_3pid(self):
+    def test_invite_3pid(self) -> None:
         """Ensure that a 3PID invite does not attempt to contact the identity server."""
         identity_handler = self.hs.get_identity_handler()
         identity_handler.lookup_3pid = Mock(
@@ -96,12 +100,12 @@ class RoomTestCase(_ShadowBannedBase):
             {"id_server": "test", "medium": "email", "address": "test@test.test"},
             access_token=self.banned_access_token,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         # This should have raised an error earlier, but double check this wasn't called.
         identity_handler.lookup_3pid.assert_not_called()
 
-    def test_create_room(self):
+    def test_create_room(self) -> None:
         """Invitations during a room creation should be discarded, but the room still gets created."""
         # The room creation is successful.
         channel = self.make_request(
@@ -110,7 +114,7 @@ class RoomTestCase(_ShadowBannedBase):
             {"visibility": "public", "invite": [self.other_user_id]},
             access_token=self.banned_access_token,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         room_id = channel.json_body["room_id"]
 
         # But the user wasn't actually invited.
@@ -126,7 +130,7 @@ class RoomTestCase(_ShadowBannedBase):
         users = self.get_success(self.store.get_users_in_room(room_id))
         self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
 
-    def test_message(self):
+    def test_message(self) -> None:
         """Messages from shadow-banned users don't actually get sent."""
 
         room_id = self.helper.create_room_as(
@@ -151,7 +155,7 @@ class RoomTestCase(_ShadowBannedBase):
         )
         self.assertNotIn(event_id, latest_events)
 
-    def test_upgrade(self):
+    def test_upgrade(self) -> None:
         """A room upgrade should fail, but look like it succeeded."""
 
         # The create works fine.
@@ -165,7 +169,7 @@ class RoomTestCase(_ShadowBannedBase):
             {"new_version": "6"},
             access_token=self.banned_access_token,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         # A new room_id should be returned.
         self.assertIn("replacement_room", channel.json_body)
 
@@ -177,7 +181,7 @@ class RoomTestCase(_ShadowBannedBase):
         # The summary should be empty since the room doesn't exist.
         self.assertEqual(summary, {})
 
-    def test_typing(self):
+    def test_typing(self) -> None:
         """Typing notifications should not be propagated into the room."""
         # The create works fine.
         room_id = self.helper.create_room_as(
@@ -190,11 +194,11 @@ class RoomTestCase(_ShadowBannedBase):
             {"typing": True, "timeout": 30000},
             access_token=self.banned_access_token,
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
         # There should be no typing events.
         event_source = self.hs.get_event_sources().sources.typing
-        self.assertEquals(event_source.get_current_key(), 0)
+        self.assertEqual(event_source.get_current_key(), 0)
 
         # The other user can join and send typing events.
         self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
@@ -205,10 +209,10 @@ class RoomTestCase(_ShadowBannedBase):
             {"typing": True, "timeout": 30000},
             access_token=self.other_access_token,
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
         # These appear in the room.
-        self.assertEquals(event_source.get_current_key(), 1)
+        self.assertEqual(event_source.get_current_key(), 1)
         events = self.get_success(
             event_source.get_new_events(
                 user=UserID.from_string(self.other_user_id),
@@ -218,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase):
                 is_guest=False,
             )
         )
-        self.assertEquals(
+        self.assertEqual(
             events[0],
             [
                 {
@@ -240,7 +244,7 @@ class ProfileTestCase(_ShadowBannedBase):
         room.register_servlets,
     ]
 
-    def test_displayname(self):
+    def test_displayname(self) -> None:
         """Profile changes should succeed, but don't end up in a room."""
         original_display_name = "banned"
         new_display_name = "new name"
@@ -257,7 +261,7 @@ class ProfileTestCase(_ShadowBannedBase):
             {"displayname": new_display_name},
             access_token=self.banned_access_token,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(channel.json_body, {})
 
         # The user's display name should be updated.
@@ -281,7 +285,7 @@ class ProfileTestCase(_ShadowBannedBase):
             event.content, {"membership": "join", "displayname": original_display_name}
         )
 
-    def test_room_displayname(self):
+    def test_room_displayname(self) -> None:
         """Changes to state events for a room should be processed, but not end up in the room."""
         original_display_name = "banned"
         new_display_name = "new name"
@@ -299,7 +303,7 @@ class ProfileTestCase(_ShadowBannedBase):
             {"membership": "join", "displayname": new_display_name},
             access_token=self.banned_access_token,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         self.assertIn("event_id", channel.json_body)
 
         # The display name in the room should not be changed.
diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py
index 283eccd53f..3818b7b14b 100644
--- a/tests/rest/client/test_shared_rooms.py
+++ b/tests/rest/client/test_shared_rooms.py
@@ -11,8 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.rest.client import login, room, shared_rooms
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -30,16 +34,16 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         shared_rooms.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["update_user_directory"] = True
         return self.setup_test_homeserver(config=config)
 
-    def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
         self.handler = hs.get_user_directory_handler()
 
-    def _get_shared_rooms(self, token, other_user) -> FakeChannel:
+    def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel:
         return self.make_request(
             "GET",
             "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
@@ -47,14 +51,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
             access_token=token,
         )
 
-    def test_shared_room_list_public(self):
+    def test_shared_room_list_public(self) -> None:
         """
         A room should show up in the shared list of rooms between two users
         if it is public.
         """
         self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
 
-    def test_shared_room_list_private(self):
+    def test_shared_room_list_private(self) -> None:
         """
         A room should show up in the shared list of rooms between two users
         if it is private.
@@ -63,7 +67,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
             room_one_is_public=False, room_two_is_public=False
         )
 
-    def test_shared_room_list_mixed(self):
+    def test_shared_room_list_mixed(self) -> None:
         """
         The shared room list between two users should contain both public and private
         rooms.
@@ -72,7 +76,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
 
     def _check_shared_rooms_with(
         self, room_one_is_public: bool, room_two_is_public: bool
-    ):
+    ) -> None:
         """Checks that shared public or private rooms between two users appear in
         their shared room lists
         """
@@ -91,9 +95,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         # Check shared rooms from user1's perspective.
         # We should see the one room in common
         channel = self._get_shared_rooms(u1_token, u2)
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(len(channel.json_body["joined"]), 1)
-        self.assertEquals(channel.json_body["joined"][0], room_id_one)
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(len(channel.json_body["joined"]), 1)
+        self.assertEqual(channel.json_body["joined"][0], room_id_one)
 
         # Create another room and invite user2 to it
         room_id_two = self.helper.create_room_as(
@@ -104,12 +108,12 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
 
         # Check shared rooms again. We should now see both rooms.
         channel = self._get_shared_rooms(u1_token, u2)
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(len(channel.json_body["joined"]), 2)
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(len(channel.json_body["joined"]), 2)
         for room_id_id in channel.json_body["joined"]:
             self.assertIn(room_id_id, [room_id_one, room_id_two])
 
-    def test_shared_room_list_after_leave(self):
+    def test_shared_room_list_after_leave(self) -> None:
         """
         A room should no longer be considered shared if the other
         user has left it.
@@ -125,18 +129,18 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
 
         # Assert user directory is not empty
         channel = self._get_shared_rooms(u1_token, u2)
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(len(channel.json_body["joined"]), 1)
-        self.assertEquals(channel.json_body["joined"][0], room)
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(len(channel.json_body["joined"]), 1)
+        self.assertEqual(channel.json_body["joined"][0], room)
 
         self.helper.leave(room, user=u1, tok=u1_token)
 
         # Check user1's view of shared rooms with user2
         channel = self._get_shared_rooms(u1_token, u2)
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(len(channel.json_body["joined"]), 0)
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(len(channel.json_body["joined"]), 0)
 
         # Check user2's view of shared rooms with user1
         channel = self._get_shared_rooms(u2_token, u1)
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(len(channel.json_body["joined"]), 0)
+        self.assertEqual(200, channel.code, channel.result)
+        self.assertEqual(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index cd4af2b1f3..4351013952 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -13,9 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+from typing import List, Optional
 
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.constants import (
     EventContentFields,
@@ -24,6 +27,9 @@ from synapse.api.constants import (
     RelationTypes,
 )
 from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.federation.transport.test_knocking import (
@@ -43,7 +49,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def test_sync_argless(self):
+    def test_sync_argless(self) -> None:
         channel = self.make_request("GET", "/sync")
 
         self.assertEqual(channel.code, 200)
@@ -58,7 +64,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def test_sync_filter_labels(self):
+    def test_sync_filter_labels(self) -> None:
         """Test that we can filter by a label."""
         sync_filter = json.dumps(
             {
@@ -77,7 +83,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
         self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
 
-    def test_sync_filter_not_labels(self):
+    def test_sync_filter_not_labels(self) -> None:
         """Test that we can filter by the absence of a label."""
         sync_filter = json.dumps(
             {
@@ -99,7 +105,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             events[2]["content"]["body"], "with two wrong labels", events[2]
         )
 
-    def test_sync_filter_labels_not_labels(self):
+    def test_sync_filter_labels_not_labels(self) -> None:
         """Test that we can filter by both a label and the absence of another label."""
         sync_filter = json.dumps(
             {
@@ -118,7 +124,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(events), 1, [event["content"] for event in events])
         self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
 
-    def _test_sync_filter_labels(self, sync_filter):
+    def _test_sync_filter_labels(self, sync_filter: str) -> List[JsonDict]:
         user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
 
@@ -194,7 +200,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
     user_id = True
     hijack_auth = False
 
-    def test_sync_backwards_typing(self):
+    def test_sync_backwards_typing(self) -> None:
         """
         If the typing serial goes backwards and the typing handler is then reset
         (such as when the master restarts and sets the typing serial to 0), we
@@ -231,10 +237,10 @@ class SyncTypingTests(unittest.HomeserverTestCase):
             typing_url % (room, other_user_id, other_access_token),
             b'{"typing": true, "timeout": 30000}',
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
         channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,))
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
         # Stop typing.
@@ -243,7 +249,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
             typing_url % (room, other_user_id, other_access_token),
             b'{"typing": false}',
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
         # Start typing.
         channel = self.make_request(
@@ -251,11 +257,11 @@ class SyncTypingTests(unittest.HomeserverTestCase):
             typing_url % (room, other_user_id, other_access_token),
             b'{"typing": true, "timeout": 30000}',
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
         # Should return immediately
         channel = self.make_request("GET", sync_url % (access_token, next_batch))
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
         # Reset typing serial back to 0, as if the master had.
@@ -267,7 +273,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         self.helper.send(room, body="There!", tok=other_access_token)
 
         channel = self.make_request("GET", sync_url % (access_token, next_batch))
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
         # This should time out! But it does not, because our stream token is
@@ -275,7 +281,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
         # already seen) is new, since it's got a token above our new, now-reset
         # stream token.
         channel = self.make_request("GET", sync_url % (access_token, next_batch))
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
         next_batch = channel.json_body["next_batch"]
 
         # Clear the typing information, so that it doesn't think everything is
@@ -298,8 +304,8 @@ class SyncKnockTestCase(
         knock.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
 
@@ -336,7 +342,7 @@ class SyncKnockTestCase(
         )
 
     @override_config({"experimental_features": {"msc2403_enabled": True}})
-    def test_knock_room_state(self):
+    def test_knock_room_state(self) -> None:
         """Tests that /sync returns state from a room after knocking on it."""
         # Knock on a room
         channel = self.make_request(
@@ -345,7 +351,7 @@ class SyncKnockTestCase(
             b"{}",
             self.knocker_tok,
         )
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         # We expect to see the knock event in the stripped room state later
         self.expected_room_state[EventTypes.Member] = {
@@ -383,7 +389,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
 
@@ -402,7 +408,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
 
     @override_config({"experimental_features": {"msc2285_enabled": True}})
-    def test_hidden_read_receipts(self):
+    def test_hidden_read_receipts(self) -> None:
         # Send a message as the first user
         res = self.helper.send(self.room_id, body="hello", tok=self.tok)
 
@@ -441,8 +447,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         ]
     )
     def test_read_receipt_with_empty_body(
-        self, name, user_agent: str, expected_status_code: int
-    ):
+        self, name: str, user_agent: str, expected_status_code: int
+    ) -> None:
         # Send a message as the first user
         res = self.helper.send(self.room_id, body="hello", tok=self.tok)
 
@@ -455,11 +461,11 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, expected_status_code)
 
-    def _get_read_receipt(self):
+    def _get_read_receipt(self) -> Optional[JsonDict]:
         """Syncs and returns the read receipt."""
 
         # Checks if event is a read receipt
-        def is_read_receipt(event):
+        def is_read_receipt(event: JsonDict) -> bool:
             return event["type"] == "m.receipt"
 
         # Sync
@@ -477,7 +483,8 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
             "ephemeral"
         ]["events"]
-        return next(filter(is_read_receipt, ephemeral_events), None)
+        receipt_event = filter(is_read_receipt, ephemeral_events)
+        return next(receipt_event, None)
 
 
 class UnreadMessagesTestCase(unittest.HomeserverTestCase):
@@ -490,7 +497,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         receipts.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
 
@@ -533,7 +540,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
             tok=self.tok,
         )
 
-    def test_unread_counts(self):
+    def test_unread_counts(self) -> None:
         """Tests that /sync returns the right value for the unread count (MSC2654)."""
 
         # Check that our own messages don't increase the unread count.
@@ -640,7 +647,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         )
         self._check_unread_count(5)
 
-    def _check_unread_count(self, expected_count: int):
+    def _check_unread_count(self, expected_count: int) -> None:
         """Syncs and compares the unread count with the expected value."""
 
         channel = self.make_request(
@@ -669,7 +676,7 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def test_noop_sync_does_not_tightloop(self):
+    def test_noop_sync_does_not_tightloop(self) -> None:
         """If the sync times out, we shouldn't cache the result
 
         Essentially a regression test for #8518.
@@ -720,7 +727,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
         devices.register_servlets,
     ]
 
-    def test_user_with_no_rooms_receives_self_device_list_updates(self):
+    def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
         """Tests that a user with no rooms still receives their own device list updates"""
         device_id = "TESTDEVICE"
 
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index ac6b86ff6b..bfc04785b7 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -15,12 +15,12 @@ import threading
 from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from unittest.mock import Mock
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, LoginType, Membership
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.rest import admin
-from synapse.rest.client import login, room
+from synapse.rest.client import account, login, profile, room
 from synapse.types import JsonDict, Requester, StateMap
 from synapse.util.frozenutils import unfreeze
 
@@ -80,6 +80,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         admin.register_servlets,
         login.register_servlets,
         room.register_servlets,
+        profile.register_servlets,
+        account.register_servlets,
     ]
 
     def make_homeserver(self, reactor, clock):
@@ -139,7 +141,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             {},
             access_token=self.tok,
         )
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         callback.assert_called_once()
 
@@ -157,7 +159,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             {},
             access_token=self.tok,
         )
-        self.assertEquals(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.result["code"], b"403", channel.result)
 
     def test_third_party_rules_workaround_synapse_errors_pass_through(self):
         """
@@ -193,7 +195,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             access_token=self.tok,
         )
         # Check the error code
-        self.assertEquals(channel.result["code"], b"429", channel.result)
+        self.assertEqual(channel.result["code"], b"429", channel.result)
         # Check the JSON body has had the `nasty` key injected
         self.assertEqual(
             channel.json_body,
@@ -329,10 +331,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             self.hs.get_module_api().create_and_send_event_into_room(event_dict)
         )
 
-        self.assertEquals(event.sender, self.user_id)
-        self.assertEquals(event.room_id, self.room_id)
-        self.assertEquals(event.type, "m.room.message")
-        self.assertEquals(event.content, content)
+        self.assertEqual(event.sender, self.user_id)
+        self.assertEqual(event.room_id, self.room_id)
+        self.assertEqual(event.type, "m.room.message")
+        self.assertEqual(event.content, content)
 
     @unittest.override_config(
         {
@@ -530,3 +532,216 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             },
             tok=self.tok,
         )
+
+    def test_on_profile_update(self):
+        """Tests that the on_profile_update module callback is correctly called on
+        profile updates.
+        """
+        displayname = "Foo"
+        avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
+
+        # Register a mock callback.
+        m = Mock(return_value=make_awaitable(None))
+        self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
+
+        # Change the display name.
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/v3/profile/%s/displayname" % self.user_id,
+            {"displayname": displayname},
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the callback has been called once for our user.
+        m.assert_called_once()
+        args = m.call_args[0]
+        self.assertEqual(args[0], self.user_id)
+
+        # Test that by_admin is False.
+        self.assertFalse(args[2])
+        # Test that deactivation is False.
+        self.assertFalse(args[3])
+
+        # Check that we've got the right profile data.
+        profile_info = args[1]
+        self.assertEqual(profile_info.display_name, displayname)
+        self.assertIsNone(profile_info.avatar_url)
+
+        # Change the avatar.
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/v3/profile/%s/avatar_url" % self.user_id,
+            {"avatar_url": avatar_url},
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the callback has been called once for our user.
+        self.assertEqual(m.call_count, 2)
+        args = m.call_args[0]
+        self.assertEqual(args[0], self.user_id)
+
+        # Test that by_admin is False.
+        self.assertFalse(args[2])
+        # Test that deactivation is False.
+        self.assertFalse(args[3])
+
+        # Check that we've got the right profile data.
+        profile_info = args[1]
+        self.assertEqual(profile_info.display_name, displayname)
+        self.assertEqual(profile_info.avatar_url, avatar_url)
+
+    def test_on_profile_update_admin(self):
+        """Tests that the on_profile_update module callback is correctly called on
+        profile updates triggered by a server admin.
+        """
+        displayname = "Foo"
+        avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
+
+        # Register a mock callback.
+        m = Mock(return_value=make_awaitable(None))
+        self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Change a user's profile.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % self.user_id,
+            {"displayname": displayname, "avatar_url": avatar_url},
+            access_token=admin_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the callback has been called twice (since we update the display name
+        # and avatar separately).
+        self.assertEqual(m.call_count, 2)
+
+        # Get the arguments for the last call and check it's about the right user.
+        args = m.call_args[0]
+        self.assertEqual(args[0], self.user_id)
+
+        # Check that by_admin is True.
+        self.assertTrue(args[2])
+        # Test that deactivation is False.
+        self.assertFalse(args[3])
+
+        # Check that we've got the right profile data.
+        profile_info = args[1]
+        self.assertEqual(profile_info.display_name, displayname)
+        self.assertEqual(profile_info.avatar_url, avatar_url)
+
+    def test_on_user_deactivation_status_changed(self):
+        """Tests that the on_user_deactivation_status_changed module callback is called
+        correctly when processing a user's deactivation.
+        """
+        # Register a mocked callback.
+        deactivation_mock = Mock(return_value=make_awaitable(None))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._on_user_deactivation_status_changed_callbacks.append(
+            deactivation_mock,
+        )
+        # Also register a mocked callback for profile updates, to check that the
+        # deactivation code calls it in a way that let modules know the user is being
+        # deactivated.
+        profile_mock = Mock(return_value=make_awaitable(None))
+        self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(
+            profile_mock,
+        )
+
+        # Register a user that we'll deactivate.
+        user_id = self.register_user("altan", "password")
+        tok = self.login("altan", "password")
+
+        # Deactivate that user.
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/account/deactivate",
+            {
+                "auth": {
+                    "type": LoginType.PASSWORD,
+                    "password": "password",
+                    "identifier": {
+                        "type": "m.id.user",
+                        "user": user_id,
+                    },
+                },
+                "erase": True,
+            },
+            access_token=tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the mock was called once.
+        deactivation_mock.assert_called_once()
+        args = deactivation_mock.call_args[0]
+
+        # Check that the mock was called with the right user ID, and with a True
+        # deactivated flag and a False by_admin flag.
+        self.assertEqual(args[0], user_id)
+        self.assertTrue(args[1])
+        self.assertFalse(args[2])
+
+        # Check that the profile update callback was called twice (once for the display
+        # name and once for the avatar URL), and that the "deactivation" boolean is true.
+        self.assertEqual(profile_mock.call_count, 2)
+        args = profile_mock.call_args[0]
+        self.assertTrue(args[3])
+
+    def test_on_user_deactivation_status_changed_admin(self):
+        """Tests that the on_user_deactivation_status_changed module callback is called
+        correctly when processing a user's deactivation triggered by a server admin as
+        well as a reactivation.
+        """
+        # Register a mock callback.
+        m = Mock(return_value=make_awaitable(None))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Register a user that we'll deactivate.
+        user_id = self.register_user("altan", "password")
+
+        # Deactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {"deactivated": True},
+            access_token=admin_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the mock was called once.
+        m.assert_called_once()
+        args = m.call_args[0]
+
+        # Check that the mock was called with the right user ID, and with True deactivated
+        # and by_admin flags.
+        self.assertEqual(args[0], user_id)
+        self.assertTrue(args[1])
+        self.assertTrue(args[2])
+
+        # Reactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {"deactivated": False, "password": "hackme"},
+            access_token=admin_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the mock was called once.
+        self.assertEqual(m.call_count, 2)
+        args = m.call_args[0]
+
+        # Check that the mock was called with the right user ID, and with a False
+        # deactivated flag and a True by_admin flag.
+        self.assertEqual(args[0], user_id)
+        self.assertFalse(args[1])
+        self.assertTrue(args[2])
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index ee0abd5295..8b2da88e8a 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -57,7 +57,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         async def _insert_client_ip(*args, **kwargs):
             return None
 
-        hs.get_datastore().insert_client_ip = _insert_client_ip
+        hs.get_datastores().main.insert_client_ip = _insert_client_ip
 
         return hs
 
@@ -72,9 +72,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": true, "timeout": 30000}',
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
-        self.assertEquals(self.event_source.get_current_key(), 1)
+        self.assertEqual(self.event_source.get_current_key(), 1)
         events = self.get_success(
             self.event_source.get_new_events(
                 user=UserID.from_string(self.user_id),
@@ -84,7 +84,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
                 is_guest=False,
             )
         )
-        self.assertEquals(
+        self.assertEqual(
             events[0],
             [
                 {
@@ -101,7 +101,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": false}',
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
     def test_typing_timeout(self):
         channel = self.make_request(
@@ -109,19 +109,19 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": true, "timeout": 30000}',
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
-        self.assertEquals(self.event_source.get_current_key(), 1)
+        self.assertEqual(self.event_source.get_current_key(), 1)
 
         self.reactor.advance(36)
 
-        self.assertEquals(self.event_source.get_current_key(), 2)
+        self.assertEqual(self.event_source.get_current_key(), 2)
 
         channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
             b'{"typing": true, "timeout": 30000}',
         )
-        self.assertEquals(200, channel.code)
+        self.assertEqual(200, channel.code)
 
-        self.assertEquals(self.event_source.get_current_key(), 3)
+        self.assertEqual(self.event_source.get_current_key(), 3)
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index a42388b26f..b7d0f42daf 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 from typing import Optional
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventContentFields, EventTypes, RoomTypes
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.rest import admin
 from synapse.rest.client import login, room, room_upgrade_rest_servlet
 from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -31,8 +34,8 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
         room_upgrade_rest_servlet.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
 
         self.creator = self.register_user("creator", "pass")
         self.creator_token = self.login(self.creator, "pass")
@@ -60,15 +63,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
             access_token=token or self.creator_token,
         )
 
-    def test_upgrade(self):
+    def test_upgrade(self) -> None:
         """
         Upgrading a room should work fine.
         """
         channel = self._upgrade_room()
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         self.assertIn("replacement_room", channel.json_body)
 
-    def test_not_in_room(self):
+    def test_not_in_room(self) -> None:
         """
         Upgrading a room should work fine.
         """
@@ -77,15 +80,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
         roomless_token = self.login(roomless, "pass")
 
         channel = self._upgrade_room(roomless_token)
-        self.assertEquals(403, channel.code, channel.result)
+        self.assertEqual(403, channel.code, channel.result)
 
-    def test_power_levels(self):
+    def test_power_levels(self) -> None:
         """
         Another user can upgrade the room if their power level is increased.
         """
         # The other user doesn't have the proper power level.
         channel = self._upgrade_room(self.other_token)
-        self.assertEquals(403, channel.code, channel.result)
+        self.assertEqual(403, channel.code, channel.result)
 
         # Increase the power levels so that this user can upgrade.
         power_levels = self.helper.get_state(
@@ -103,15 +106,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
 
         # The upgrade should succeed!
         channel = self._upgrade_room(self.other_token)
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
-    def test_power_levels_user_default(self):
+    def test_power_levels_user_default(self) -> None:
         """
         Another user can upgrade the room if the default power level for users is increased.
         """
         # The other user doesn't have the proper power level.
         channel = self._upgrade_room(self.other_token)
-        self.assertEquals(403, channel.code, channel.result)
+        self.assertEqual(403, channel.code, channel.result)
 
         # Increase the power levels so that this user can upgrade.
         power_levels = self.helper.get_state(
@@ -129,15 +132,15 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
 
         # The upgrade should succeed!
         channel = self._upgrade_room(self.other_token)
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
-    def test_power_levels_tombstone(self):
+    def test_power_levels_tombstone(self) -> None:
         """
         Another user can upgrade the room if they can send the tombstone event.
         """
         # The other user doesn't have the proper power level.
         channel = self._upgrade_room(self.other_token)
-        self.assertEquals(403, channel.code, channel.result)
+        self.assertEqual(403, channel.code, channel.result)
 
         # Increase the power levels so that this user can upgrade.
         power_levels = self.helper.get_state(
@@ -155,7 +158,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
 
         # The upgrade should succeed!
         channel = self._upgrade_room(self.other_token)
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
 
         power_levels = self.helper.get_state(
             self.room_id,
@@ -164,7 +167,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
         )
         self.assertNotIn(self.other, power_levels["users"])
 
-    def test_space(self):
+    def test_space(self) -> None:
         """Test upgrading a space."""
 
         # Create a space.
@@ -197,7 +200,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
 
         # Upgrade the room!
         channel = self._upgrade_room(room_id=space_id)
-        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(200, channel.code, channel.result)
         self.assertIn("replacement_room", channel.json_body)
 
         new_space_id = channel.json_body["replacement_room"]
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 2b3fdadffa..28663826fc 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -19,6 +19,7 @@ import json
 import re
 import time
 import urllib.parse
+from http import HTTPStatus
 from typing import (
     Any,
     AnyStr,
@@ -40,6 +41,7 @@ from twisted.web.resource import Resource
 from twisted.web.server import Site
 
 from synapse.api.constants import Membership
+from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests.server import FakeChannel, FakeSite, make_request
@@ -47,15 +49,15 @@ from tests.test_utils import FakeResponse
 from tests.test_utils.html_parsers import TestHtmlParser
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class RestHelper:
     """Contains extra helper functions to quickly and clearly perform a given
     REST action, which isn't the focus of the test.
     """
 
-    hs = attr.ib()
-    site = attr.ib(type=Site)
-    auth_user_id = attr.ib()
+    hs: HomeServer
+    site: Site
+    auth_user_id: Optional[str]
 
     @overload
     def create_room_as(
@@ -89,7 +91,7 @@ class RestHelper:
         is_public: Optional[bool] = None,
         room_version: Optional[str] = None,
         tok: Optional[str] = None,
-        expect_code: int = 200,
+        expect_code: int = HTTPStatus.OK,
         extra_content: Optional[Dict] = None,
         custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ) -> Optional[str]:
@@ -137,12 +139,19 @@ class RestHelper:
         assert channel.result["code"] == b"%d" % expect_code, channel.result
         self.auth_user_id = temp_id
 
-        if expect_code == 200:
+        if expect_code == HTTPStatus.OK:
             return channel.json_body["room_id"]
         else:
             return None
 
-    def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
+    def invite(
+        self,
+        room: str,
+        src: Optional[str] = None,
+        targ: Optional[str] = None,
+        expect_code: int = HTTPStatus.OK,
+        tok: Optional[str] = None,
+    ) -> None:
         self.change_membership(
             room=room,
             src=src,
@@ -156,7 +165,7 @@ class RestHelper:
         self,
         room: str,
         user: Optional[str] = None,
-        expect_code: int = 200,
+        expect_code: int = HTTPStatus.OK,
         tok: Optional[str] = None,
         appservice_user_id: Optional[str] = None,
     ) -> None:
@@ -170,7 +179,14 @@ class RestHelper:
             expect_code=expect_code,
         )
 
-    def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None):
+    def knock(
+        self,
+        room: Optional[str] = None,
+        user: Optional[str] = None,
+        reason: Optional[str] = None,
+        expect_code: int = HTTPStatus.OK,
+        tok: Optional[str] = None,
+    ) -> None:
         temp_id = self.auth_user_id
         self.auth_user_id = user
         path = "/knock/%s" % room
@@ -199,7 +215,13 @@ class RestHelper:
 
         self.auth_user_id = temp_id
 
-    def leave(self, room=None, user=None, expect_code=200, tok=None):
+    def leave(
+        self,
+        room: str,
+        user: Optional[str] = None,
+        expect_code: int = HTTPStatus.OK,
+        tok: Optional[str] = None,
+    ) -> None:
         self.change_membership(
             room=room,
             src=user,
@@ -209,14 +231,22 @@ class RestHelper:
             expect_code=expect_code,
         )
 
-    def ban(self, room: str, src: str, targ: str, **kwargs: object):
+    def ban(
+        self,
+        room: str,
+        src: str,
+        targ: str,
+        expect_code: int = HTTPStatus.OK,
+        tok: Optional[str] = None,
+    ) -> None:
         """A convenience helper: `change_membership` with `membership` preset to "ban"."""
         self.change_membership(
             room=room,
             src=src,
             targ=targ,
+            tok=tok,
             membership=Membership.BAN,
-            **kwargs,
+            expect_code=expect_code,
         )
 
     def change_membership(
@@ -228,7 +258,7 @@ class RestHelper:
         extra_data: Optional[dict] = None,
         tok: Optional[str] = None,
         appservice_user_id: Optional[str] = None,
-        expect_code: int = 200,
+        expect_code: int = HTTPStatus.OK,
         expect_errcode: Optional[str] = None,
     ) -> None:
         """
@@ -294,13 +324,13 @@ class RestHelper:
 
     def send(
         self,
-        room_id,
-        body=None,
-        txn_id=None,
-        tok=None,
-        expect_code=200,
+        room_id: str,
+        body: Optional[str] = None,
+        txn_id: Optional[str] = None,
+        tok: Optional[str] = None,
+        expect_code: int = HTTPStatus.OK,
         custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-    ):
+    ) -> JsonDict:
         if body is None:
             body = "body_text_here"
 
@@ -318,14 +348,14 @@ class RestHelper:
 
     def send_event(
         self,
-        room_id,
-        type,
+        room_id: str,
+        type: str,
         content: Optional[dict] = None,
-        txn_id=None,
-        tok=None,
-        expect_code=200,
+        txn_id: Optional[str] = None,
+        tok: Optional[str] = None,
+        expect_code: int = HTTPStatus.OK,
         custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-    ):
+    ) -> JsonDict:
         if txn_id is None:
             txn_id = "m%s" % (str(time.time()))
 
@@ -357,11 +387,11 @@ class RestHelper:
         room_id: str,
         event_type: str,
         body: Optional[Dict[str, Any]],
-        tok: str,
-        expect_code: int = 200,
+        tok: Optional[str],
+        expect_code: int = HTTPStatus.OK,
         state_key: str = "",
         method: str = "GET",
-    ) -> Dict:
+    ) -> JsonDict:
         """Read or write some state from a given room
 
         Args:
@@ -410,9 +440,9 @@ class RestHelper:
         room_id: str,
         event_type: str,
         tok: str,
-        expect_code: int = 200,
+        expect_code: int = HTTPStatus.OK,
         state_key: str = "",
-    ):
+    ) -> JsonDict:
         """Gets some state from a room
 
         Args:
@@ -437,10 +467,10 @@ class RestHelper:
         room_id: str,
         event_type: str,
         body: Dict[str, Any],
-        tok: str,
-        expect_code: int = 200,
+        tok: Optional[str],
+        expect_code: int = HTTPStatus.OK,
         state_key: str = "",
-    ):
+    ) -> JsonDict:
         """Set some state in a room
 
         Args:
@@ -467,8 +497,8 @@ class RestHelper:
         image_data: bytes,
         tok: str,
         filename: str = "test.png",
-        expect_code: int = 200,
-    ) -> dict:
+        expect_code: int = HTTPStatus.OK,
+    ) -> JsonDict:
         """Upload a piece of test media to the media repo
         Args:
             resource: The resource that will handle the upload request
@@ -513,7 +543,7 @@ class RestHelper:
         channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
 
         # expect a confirmation page
-        assert channel.code == 200, channel.result
+        assert channel.code == HTTPStatus.OK, channel.result
 
         # fish the matrix login token out of the body of the confirmation page
         m = re.search(
@@ -532,7 +562,7 @@ class RestHelper:
             "/login",
             content={"type": "m.login.token", "token": login_token},
         )
-        assert channel.code == 200
+        assert channel.code == HTTPStatus.OK
         return channel.json_body
 
     def auth_via_oidc(
@@ -637,11 +667,16 @@ class RestHelper:
             (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
         ]
 
-        async def mock_req(method: str, uri: str, data=None, headers=None):
+        async def mock_req(
+            method: str,
+            uri: str,
+            data: Optional[dict] = None,
+            headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        ):
             (expected_uri, resp_obj) = expected_requests.pop(0)
             assert uri == expected_uri
             resp = FakeResponse(
-                code=200,
+                code=HTTPStatus.OK,
                 phrase=b"OK",
                 body=json.dumps(resp_obj).encode("utf-8"),
             )
@@ -739,7 +774,7 @@ class RestHelper:
             self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
         )
         # that should serve a confirmation page
-        assert channel.code == 200, channel.text_body
+        assert channel.code == HTTPStatus.OK, channel.text_body
         channel.extract_cookies(cookies)
 
         # parse the confirmation page to fish out the link.
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 4cf1ed5ddf..cba9be17c4 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -94,7 +94,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
         self.assertTrue(os.path.exists(local_path))
 
         # Asserts the file is under the expected local cache directory
-        self.assertEquals(
+        self.assertEqual(
             os.path.commonprefix([self.primary_base_path, local_path]),
             self.primary_base_path,
         )
@@ -243,7 +243,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         media_resource = hs.get_media_repository_resource()
         self.download_resource = media_resource.children[b"download"]
         self.thumbnail_resource = media_resource.children[b"thumbnail"]
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.media_repo = hs.get_media_repository()
 
         self.media_id = "example.com/12345"
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 36c495954f..02b96c9e6e 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -242,7 +242,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         return c
 
     def prepare(self, reactor, clock, hs):
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.server_notices_sender = self.hs.get_server_notices_sender()
         self.server_notices_manager = self.hs.get_server_notices_manager()
         self.event_source = self.hs.get_event_sources()
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index 36c933b9e9..50c20c5b92 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -26,7 +26,7 @@ class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.user_id = self.register_user("foo", "pass")
 
     def test_background_remove_deleted_devices_from_device_inbox(self):
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 5ae491ff5a..1f6a9eb07b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -37,7 +37,7 @@ from tests import unittest
 
 class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.store: EventsWorkerStore = hs.get_datastore()
+        self.store: EventsWorkerStore = hs.get_datastores().main
 
         # insert some test data
         for rid in ("room1", "room2"):
@@ -88,18 +88,18 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
             res = self.get_success(
                 self.store.have_seen_events("room1", ["event10", "event19"])
             )
-            self.assertEquals(res, {"event10"})
+            self.assertEqual(res, {"event10"})
 
             # that should result in a single db query
-            self.assertEquals(ctx.get_resource_usage().db_txn_count, 1)
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
 
         # a second lookup of the same events should cause no queries
         with LoggingContext(name="test") as ctx:
             res = self.get_success(
                 self.store.have_seen_events("room1", ["event10", "event19"])
             )
-            self.assertEquals(res, {"event10"})
-            self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
+            self.assertEqual(res, {"event10"})
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
 
     def test_query_via_event_cache(self):
         # fetch an event into the event cache
@@ -108,8 +108,8 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
         # looking it up should now cause no db hits
         with LoggingContext(name="test") as ctx:
             res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
-            self.assertEquals(res, {"event10"})
-            self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)
+            self.assertEqual(res, {"event10"})
+            self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
 
 
 class EventCacheTestCase(unittest.HomeserverTestCase):
@@ -122,7 +122,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store: EventsWorkerStore = hs.get_datastore()
+        self.store: EventsWorkerStore = hs.get_datastores().main
 
         self.user = self.register_user("user", "pass")
         self.token = self.login(self.user, "pass")
@@ -163,7 +163,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase):
     """Test event fetching during a database outage."""
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
-        self.store: EventsWorkerStore = hs.get_datastore()
+        self.store: EventsWorkerStore = hs.get_datastores().main
 
         self.room_id = f"!room:{hs.hostname}"
         self.event_ids = [f"event{i}" for i in range(20)]
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index d326a1d6a6..3ac4646969 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -20,7 +20,7 @@ from tests import unittest
 
 class LockTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs: HomeServer):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def test_simple_lock(self):
         """Test that we can take out a lock and that while we hold it nobody
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 7496974da3..9abd0cb446 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -28,7 +28,7 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.user_id = self.register_user("foo", "pass")
         self.token = self.login("foo", "pass")
 
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 200b9198f9..4899cd5c36 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -20,7 +20,7 @@ from tests import unittest
 
 class UpsertManyTests(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.storage = hs.get_datastore()
+        self.storage = hs.get_datastores().main
 
         self.table_name = "table_" + secrets.token_hex(6)
         self.get_success(
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index d697d2bc1e..272cd35402 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -21,7 +21,7 @@ from tests import unittest
 
 class IgnoredUsersTestCase(unittest.HomeserverTestCase):
     def prepare(self, hs, reactor, clock):
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.user = "@user:test"
 
     def _update_ignore_list(
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ddcb7f5549..ee599f4336 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -88,21 +88,21 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
 
     def test_retrieve_unknown_service_token(self) -> None:
         service = self.store.get_app_service_by_token("invalid_token")
-        self.assertEquals(service, None)
+        self.assertEqual(service, None)
 
     def test_retrieval_of_service(self) -> None:
         stored_service = self.store.get_app_service_by_token(self.as_token)
         assert stored_service is not None
-        self.assertEquals(stored_service.token, self.as_token)
-        self.assertEquals(stored_service.id, self.as_id)
-        self.assertEquals(stored_service.url, self.as_url)
-        self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], [])
-        self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], [])
-        self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], [])
+        self.assertEqual(stored_service.token, self.as_token)
+        self.assertEqual(stored_service.id, self.as_id)
+        self.assertEqual(stored_service.url, self.as_url)
+        self.assertEqual(stored_service.namespaces[ApplicationService.NS_ALIASES], [])
+        self.assertEqual(stored_service.namespaces[ApplicationService.NS_ROOMS], [])
+        self.assertEqual(stored_service.namespaces[ApplicationService.NS_USERS], [])
 
     def test_retrieval_of_all_services(self) -> None:
         services = self.store.get_app_services()
-        self.assertEquals(len(services), 3)
+        self.assertEqual(len(services), 3)
 
 
 class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
@@ -182,7 +182,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
     ) -> None:
         service = Mock(id="999")
         state = self.get_success(self.store.get_appservice_state(service))
-        self.assertEquals(None, state)
+        self.assertEqual(None, state)
 
     def test_get_appservice_state_up(
         self,
@@ -194,7 +194,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         state = self.get_success(
             defer.ensureDeferred(self.store.get_appservice_state(service))
         )
-        self.assertEquals(ApplicationServiceState.UP, state)
+        self.assertEqual(ApplicationServiceState.UP, state)
 
     def test_get_appservice_state_down(
         self,
@@ -210,7 +210,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         )
         service = Mock(id=self.as_list[1]["id"])
         state = self.get_success(self.store.get_appservice_state(service))
-        self.assertEquals(ApplicationServiceState.DOWN, state)
+        self.assertEqual(ApplicationServiceState.DOWN, state)
 
     def test_get_appservices_by_state_none(
         self,
@@ -218,7 +218,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         services = self.get_success(
             self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
-        self.assertEquals(0, len(services))
+        self.assertEqual(0, len(services))
 
     def test_set_appservices_state_down(
         self,
@@ -235,7 +235,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
                 (ApplicationServiceState.DOWN.value,),
             )
         )
-        self.assertEquals(service.id, rows[0][0])
+        self.assertEqual(service.id, rows[0][0])
 
     def test_set_appservices_state_multiple_up(
         self,
@@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
                 (ApplicationServiceState.UP.value,),
             )
         )
-        self.assertEquals(service.id, rows[0][0])
+        self.assertEqual(service.id, rows[0][0])
 
     def test_create_appservice_txn_first(
         self,
@@ -267,12 +267,12 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
         txn = self.get_success(
             defer.ensureDeferred(
-                self.store.create_appservice_txn(service, events, [], [])
+                self.store.create_appservice_txn(service, events, [], [], {}, {})
             )
         )
-        self.assertEquals(txn.id, 1)
-        self.assertEquals(txn.events, events)
-        self.assertEquals(txn.service, service)
+        self.assertEqual(txn.id, 1)
+        self.assertEqual(txn.events, events)
+        self.assertEqual(txn.service, service)
 
     def test_create_appservice_txn_older_last_txn(
         self,
@@ -283,11 +283,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         self.get_success(self._insert_txn(service.id, 9644, events))
         self.get_success(self._insert_txn(service.id, 9645, events))
         txn = self.get_success(
-            self.store.create_appservice_txn(service, events, [], [])
+            self.store.create_appservice_txn(service, events, [], [], {}, {})
         )
-        self.assertEquals(txn.id, 9646)
-        self.assertEquals(txn.events, events)
-        self.assertEquals(txn.service, service)
+        self.assertEqual(txn.id, 9646)
+        self.assertEqual(txn.events, events)
+        self.assertEqual(txn.service, service)
 
     def test_create_appservice_txn_up_to_date_last_txn(
         self,
@@ -296,11 +296,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
         self.get_success(self._set_last_txn(service.id, 9643))
         txn = self.get_success(
-            self.store.create_appservice_txn(service, events, [], [])
+            self.store.create_appservice_txn(service, events, [], [], {}, {})
         )
-        self.assertEquals(txn.id, 9644)
-        self.assertEquals(txn.events, events)
-        self.assertEquals(txn.service, service)
+        self.assertEqual(txn.id, 9644)
+        self.assertEqual(txn.events, events)
+        self.assertEqual(txn.service, service)
 
     def test_create_appservice_txn_up_fuzzing(
         self,
@@ -320,11 +320,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
 
         txn = self.get_success(
-            self.store.create_appservice_txn(service, events, [], [])
+            self.store.create_appservice_txn(service, events, [], [], {}, {})
         )
-        self.assertEquals(txn.id, 9644)
-        self.assertEquals(txn.events, events)
-        self.assertEquals(txn.service, service)
+        self.assertEqual(txn.id, 9644)
+        self.assertEqual(txn.events, events)
+        self.assertEqual(txn.service, service)
 
     def test_complete_appservice_txn_first_txn(
         self,
@@ -346,8 +346,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
                 (service.id,),
             )
         )
-        self.assertEquals(1, len(res))
-        self.assertEquals(txn_id, res[0][0])
+        self.assertEqual(1, len(res))
+        self.assertEqual(txn_id, res[0][0])
 
         res = self.get_success(
             self.db_pool.runQuery(
@@ -357,7 +357,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
                 (txn_id,),
             )
         )
-        self.assertEquals(0, len(res))
+        self.assertEqual(0, len(res))
 
     def test_complete_appservice_txn_existing_in_state_table(
         self,
@@ -379,9 +379,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
                 (service.id,),
             )
         )
-        self.assertEquals(1, len(res))
-        self.assertEquals(txn_id, res[0][0])
-        self.assertEquals(ApplicationServiceState.UP.value, res[0][1])
+        self.assertEqual(1, len(res))
+        self.assertEqual(txn_id, res[0][0])
+        self.assertEqual(ApplicationServiceState.UP.value, res[0][1])
 
         res = self.get_success(
             self.db_pool.runQuery(
@@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
                 (txn_id,),
             )
         )
-        self.assertEquals(0, len(res))
+        self.assertEqual(0, len(res))
 
     def test_get_oldest_unsent_txn_none(
         self,
@@ -399,7 +399,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         service = Mock(id=self.as_list[0]["id"])
 
         txn = self.get_success(self.store.get_oldest_unsent_txn(service))
-        self.assertEquals(None, txn)
+        self.assertEqual(None, txn)
 
     def test_get_oldest_unsent_txn(self) -> None:
         service = Mock(id=self.as_list[0]["id"])
@@ -416,9 +416,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         self.get_success(self._insert_txn(service.id, 12, other_events))
 
         txn = self.get_success(self.store.get_oldest_unsent_txn(service))
-        self.assertEquals(service, txn.service)
-        self.assertEquals(10, txn.id)
-        self.assertEquals(events, txn.events)
+        self.assertEqual(service, txn.service)
+        self.assertEqual(10, txn.id)
+        self.assertEqual(events, txn.events)
 
     def test_get_appservices_by_state_single(
         self,
@@ -433,8 +433,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         services = self.get_success(
             self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
-        self.assertEquals(1, len(services))
-        self.assertEquals(self.as_list[0]["id"], services[0].id)
+        self.assertEqual(1, len(services))
+        self.assertEqual(self.as_list[0]["id"], services[0].id)
 
     def test_get_appservices_by_state_multiple(
         self,
@@ -455,8 +455,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         services = self.get_success(
             self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
-        self.assertEquals(2, len(services))
-        self.assertEquals(
+        self.assertEqual(2, len(services))
+        self.assertEqual(
             {self.as_list[2]["id"], self.as_list[0]["id"]},
             {services[0].id, services[1].id},
         )
@@ -467,7 +467,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
         self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
     ) -> None:
         self.service = Mock(id="foo")
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.get_success(
             self.store.set_appservice_state(self.service, ApplicationServiceState.UP)
         )
@@ -476,12 +476,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
         value = self.get_success(
             self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
         )
-        self.assertEquals(value, 0)
+        self.assertEqual(value, 0)
 
         value = self.get_success(
             self.store.get_type_stream_id_for_appservice(self.service, "presence")
         )
-        self.assertEquals(value, 0)
+        self.assertEqual(value, 0)
 
     def test_get_type_stream_id_for_appservice_invalid_type(self) -> None:
         self.get_failure(
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 6156dfac4e..39dcc094bd 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -24,7 +24,7 @@ from tests.test_utils import make_awaitable, simple_async_mock
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
+        self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
         # the base test class should have run the real bg updates for us
         self.assertTrue(
             self.get_success(self.updates.has_completed_background_updates())
@@ -42,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         # the target runtime for each bg update
         target_background_update_duration_ms = 100
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "background_updates",
@@ -102,7 +102,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
 class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        self.updates: BackgroundUpdater = self.hs.get_datastore().db_pool.updates
+        self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
         # the base test class should have run the real bg updates for us
         self.assertTrue(
             self.get_success(self.updates.has_completed_background_updates())
@@ -138,7 +138,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         )
 
     def test_controller(self):
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
                 "background_updates",
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 3e4f0579c9..a8ffb52c05 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -103,7 +103,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             )
         )
 
-        self.assertEquals("Value", value)
+        self.assertEqual("Value", value)
         self.mock_txn.execute.assert_called_with(
             "SELECT retcol FROM tablename WHERE keycol = ?", ["TheKey"]
         )
@@ -121,7 +121,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             )
         )
 
-        self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
+        self.assertEqual({"colA": 1, "colB": 2, "colC": 3}, ret)
         self.mock_txn.execute.assert_called_with(
             "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
         )
@@ -154,7 +154,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             )
         )
 
-        self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
+        self.assertEqual([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
         self.mock_txn.execute.assert_called_with(
             "SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
         )
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index a59c28f896..ce89c96912 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -30,7 +30,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
     """
 
     def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastore()
+        self.store = homeserver.get_datastores().main
         self.room_creator = homeserver.get_room_creation_handler()
 
         # Create a test user and room
@@ -242,7 +242,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         return self.setup_test_homeserver(config=config)
 
     def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastore()
+        self.store = homeserver.get_datastores().main
         self.room_creator = homeserver.get_room_creation_handler()
         self.event_creator_handler = homeserver.get_event_creation_handler()
 
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index c8ac67e35b..49ad3c1324 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -35,7 +35,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, hs, reactor, clock):
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
     def test_insert_new_client_ip(self):
         self.reactor.advance(12345678)
@@ -666,7 +666,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, hs, reactor, clock):
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
         self.user_id = self.register_user("bob", "abc123", True)
 
     def test_request_with_xforwarded(self):
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index b547bf8d99..21ffc5a909 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase
 
 class DeviceStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def test_store_new_device(self):
         self.get_success(
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 43628ce44f..20bf3ca17b 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -19,7 +19,7 @@ from tests.unittest import HomeserverTestCase
 
 class DirectoryStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#my-room:test")
@@ -31,7 +31,7 @@ class DirectoryStoreTestCase(HomeserverTestCase):
             )
         )
 
-        self.assertEquals(
+        self.assertEqual(
             ["#my-room:test"],
             (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
         )
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 7556171d8a..fb96ab3a2f 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -28,7 +28,7 @@ room_key: RoomKey = {
 class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver("server", federation_http_client=None)
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         return hs
 
     def test_room_keys_version_delete(self):
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3bf6e337f4..0f04493ad0 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -17,7 +17,7 @@ from tests.unittest import HomeserverTestCase
 
 class EndToEndKeyStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def test_key_without_device_name(self):
         now = 1470174257070
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e3273a93f9..401020fd63 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -30,7 +30,7 @@ from tests.unittest import HomeserverTestCase
 
 class EventChainStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self._next_stream_ordering = 1
 
     def test_simple(self):
@@ -492,7 +492,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.user_id = self.register_user("foo", "pass")
         self.token = self.login("foo", "pass")
         self.requester = create_requester(self.user_id)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 667ca90a4d..645d564d1c 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -31,7 +31,7 @@ import tests.utils
 
 class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
     def test_get_prev_events_for_room(self):
         room_id = "@ROOM:local"
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 738f3ad1dc..0f9add4841 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -30,7 +30,7 @@ HIGHLIGHT = [
 
 class EventPushActionsStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.persist_events_store = hs.get_datastores().persist_events
 
     def test_get_unread_push_actions_for_user_in_range_for_http(self):
@@ -57,7 +57,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
                     "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
                 )
             )
-            self.assertEquals(
+            self.assertEqual(
                 counts,
                 NotifCounts(
                     notify_count=noitf_count,
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index a8639d8f82..ef5e25873c 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -32,7 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
         self.state = self.hs.get_state_handler()
         self.persistence = self.hs.get_storage().persistence
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         self.register_user("user", "pass")
         self.token = self.login("user", "pass")
@@ -341,7 +341,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
         self.state = self.hs.get_state_handler()
         self.persistence = self.hs.get_storage().persistence
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
     def test_remote_user_rooms_cache_invalidated(self):
         """Test that if the server leaves a room the `get_rooms_for_user` cache
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 7486078284..6ac4b93f98 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -26,7 +26,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         skip = "Requires Postgres"
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
@@ -459,7 +459,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         skip = "Requires Postgres"
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
@@ -585,7 +585,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         skip = "Requires Postgres"
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index a94b5fd721..9059095525 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -37,7 +37,7 @@ KEY_2 = decode_verify_key_base64(
 
 class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
     def test_get_server_verify_keys(self):
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         key_id_1 = "ed25519:key1"
         key_id_2 = "ed25519:KEY_ID_2"
@@ -74,7 +74,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
     def test_cache(self):
         """Check that updates correctly invalidate the cache."""
 
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         key_id_1 = "ed25519:key1"
         key_id_2 = "ed25519:key2"
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index f8d11bac4e..5806cb0e4b 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -22,7 +22,7 @@ class DataStoreTestCase(unittest.HomeserverTestCase):
     def setUp(self) -> None:
         super(DataStoreTestCase, self).setUp()
 
-        self.store = self.hs.get_datastore()
+        self.store = self.hs.get_datastores().main
 
         self.user = UserID.from_string("@abcde:test")
         self.displayname = "Frank"
@@ -38,12 +38,12 @@ class DataStoreTestCase(unittest.HomeserverTestCase):
             self.store.get_users_paginate(0, 10, name="bc", guests=False)
         )
 
-        self.assertEquals(1, total)
-        self.assertEquals(self.displayname, users.pop()["displayname"])
+        self.assertEqual(1, total)
+        self.assertEqual(self.displayname, users.pop()["displayname"])
 
         users, total = self.get_success(
             self.store.get_users_paginate(0, 10, name="BC", guests=False)
         )
 
-        self.assertEquals(1, total)
-        self.assertEquals(self.displayname, users.pop()["displayname"])
+        self.assertEqual(1, total)
+        self.assertEqual(self.displayname, users.pop()["displayname"])
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index d6b4cdd788..79648d45db 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -45,7 +45,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         return config
 
     def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastore()
+        self.store = homeserver.get_datastores().main
         # Advance the clock a bit
         reactor.advance(FORTY_DAYS)
 
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index d37736edf8..a019d06e09 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -22,7 +22,7 @@ from tests import unittest
 
 class ProfileStoreTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.u_frank = UserID.from_string("@frank:test")
 
@@ -33,7 +33,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
             self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
         )
 
-        self.assertEquals(
+        self.assertEqual(
             "Frank",
             (
                 self.get_success(
@@ -60,7 +60,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals(
+        self.assertEqual(
             "http://my.site/here",
             (
                 self.get_success(
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 22a77c3ccc..08cc60237e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -30,7 +30,7 @@ class PurgeTests(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.room_id = self.helper.create_room_as(self.user_id)
 
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = self.hs.get_storage()
 
     def test_purge_history(self):
@@ -47,7 +47,7 @@ class PurgeTests(HomeserverTestCase):
         token = self.get_success(
             self.store.get_topological_token_for_event(last["event_id"])
         )
-        token_str = self.get_success(token.to_string(self.hs.get_datastore()))
+        token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
 
         # Purge everything before this topological token
         self.get_success(
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 8c95a0a2fb..03e9cc7d4a 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -30,7 +30,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         return config
 
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 9748065282..a49ac1525e 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase
 
 class RegistrationStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.user_id = "@my-user:test"
         self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"]
@@ -30,7 +30,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
     def test_register(self):
         self.get_success(self.store.register_user(self.user_id, self.pwhash))
 
-        self.assertEquals(
+        self.assertEqual(
             {
                 # TODO(paul): Surely this field should be 'user_id', not 'name'
                 "name": self.user_id,
@@ -131,7 +131,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
             ),
             ThreepidValidationError,
         )
-        self.assertEquals(e.value.msg, "Unknown session_id", e)
+        self.assertEqual(e.value.msg, "Unknown session_id", e)
 
         # Set the config setting to true.
         self.store._ignore_unknown_session_error = True
@@ -146,4 +146,4 @@ class RegistrationStoreTestCase(HomeserverTestCase):
             ),
             ThreepidValidationError,
         )
-        self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
+        self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py
index cfc8098af6..0baa54312e 100644
--- a/tests/storage/test_rollback_worker.py
+++ b/tests/storage/test_rollback_worker.py
@@ -56,7 +56,7 @@ class WorkerSchemaTests(HomeserverTestCase):
     def test_rolling_back(self):
         """Test that workers can start if the DB is a newer schema version"""
 
-        db_pool = self.hs.get_datastore().db_pool
+        db_pool = self.hs.get_datastores().main.db_pool
         db_conn = LoggingDatabaseConnection(
             db_pool._db_pool.connect(),
             db_pool.engine,
@@ -72,7 +72,7 @@ class WorkerSchemaTests(HomeserverTestCase):
 
     def test_not_upgraded_old_schema_version(self):
         """Test that workers don't start if the DB has an older schema version"""
-        db_pool = self.hs.get_datastore().db_pool
+        db_pool = self.hs.get_datastores().main.db_pool
         db_conn = LoggingDatabaseConnection(
             db_pool._db_pool.connect(),
             db_pool.engine,
@@ -92,7 +92,7 @@ class WorkerSchemaTests(HomeserverTestCase):
         Test that workers don't start if the DB is on the current schema version,
         but there are still outstanding delta migrations to run.
         """
-        db_pool = self.hs.get_datastore().db_pool
+        db_pool = self.hs.get_datastores().main.db_pool
         db_conn = LoggingDatabaseConnection(
             db_pool._db_pool.connect(),
             db_pool.engine,
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 31ce7f6252..5b011e18cd 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -23,7 +23,7 @@ class RoomStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         # We can't test RoomStore on its own without the DirectoryStore, for
         # management of the 'room_aliases' table
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#a-room-name:test")
@@ -71,7 +71,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.event_factory = hs.get_event_factory()
 
@@ -104,7 +104,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
             self.store.get_current_state(room_id=self.room.to_string())
         )
 
-        self.assertEquals(1, len(state))
+        self.assertEqual(1, len(state))
         self.assertObjectHasAttributes(
             {"type": "m.room.name", "room_id": self.room.to_string(), "name": name},
             state[0],
@@ -121,7 +121,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
             self.store.get_current_state(room_id=self.room.to_string())
         )
 
-        self.assertEquals(1, len(state))
+        self.assertEqual(1, len(state))
         self.assertObjectHasAttributes(
             {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic},
             state[0],
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 8971ecccbd..8dfc1e1db9 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -13,13 +13,16 @@
 # limitations under the License.
 
 import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.api.errors import StoreError
 from synapse.rest.client import login, room
 from synapse.storage.engines import PostgresEngine
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, skip_unless
+from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
-class NullByteInsertionTest(HomeserverTestCase):
+class EventSearchInsertionTest(HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -46,11 +49,11 @@ class NullByteInsertionTest(HomeserverTestCase):
             self.assertIn("event_id", response)
 
         # Check that search works for the message where the null byte was replaced
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         result = self.get_success(
             store.search_msgs([room_id], "hi bob", ["content.body"])
         )
-        self.assertEquals(result.get("count"), 1)
+        self.assertEqual(result.get("count"), 1)
         if isinstance(store.database_engine, PostgresEngine):
             self.assertIn("hi", result.get("highlights"))
             self.assertIn("bob", result.get("highlights"))
@@ -59,16 +62,126 @@ class NullByteInsertionTest(HomeserverTestCase):
         result = self.get_success(
             store.search_msgs([room_id], "another", ["content.body"])
         )
-        self.assertEquals(result.get("count"), 1)
+        self.assertEqual(result.get("count"), 1)
         if isinstance(store.database_engine, PostgresEngine):
             self.assertIn("another", result.get("highlights"))
 
         # Check that search works for a search term that overlaps with the message
         # containing a null byte and an unrelated message.
         result = self.get_success(store.search_msgs([room_id], "hi", ["content.body"]))
-        self.assertEquals(result.get("count"), 2)
+        self.assertEqual(result.get("count"), 2)
         result = self.get_success(
             store.search_msgs([room_id], "hi alice", ["content.body"])
         )
         if isinstance(store.database_engine, PostgresEngine):
             self.assertIn("alice", result.get("highlights"))
+
+    def test_non_string(self):
+        """Test that non-string `value`s are not inserted into `event_search`.
+
+        This is particularly important when using sqlite, since a sqlite column can hold
+        both strings and integers. When using Postgres, integers are automatically
+        converted to strings.
+
+        Regression test for #11918.
+        """
+        store = self.hs.get_datastores().main
+
+        # Register a user and create a room
+        user_id = self.register_user("alice", "password")
+        access_token = self.login("alice", "password")
+        room_id = self.helper.create_room_as("alice", tok=access_token)
+        room_version = self.get_success(store.get_room_version(room_id))
+
+        # Construct a message with a numeric body to be received over federation
+        # The message can't be sent using the client API, since Synapse's event
+        # validation will reject it.
+        prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+        prev_event = self.get_success(store.get_event(prev_event_ids[0]))
+        prev_state_map = self.get_success(
+            self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
+        )
+
+        event_dict = {
+            "type": EventTypes.Message,
+            "content": {"msgtype": "m.text", "body": 2},
+            "room_id": room_id,
+            "sender": user_id,
+            "depth": prev_event.depth + 1,
+            "prev_events": prev_event_ids,
+            "origin_server_ts": self.clock.time_msec(),
+        }
+        builder = self.hs.get_event_builder_factory().for_room_version(
+            room_version, event_dict
+        )
+        event = self.get_success(
+            builder.build(
+                prev_event_ids=prev_event_ids,
+                auth_event_ids=self.hs.get_event_auth_handler().compute_auth_events(
+                    builder,
+                    prev_state_map,
+                    for_verification=False,
+                ),
+                depth=event_dict["depth"],
+            )
+        )
+
+        # Receive the event
+        self.get_success(
+            self.hs.get_federation_event_handler().on_receive_pdu(
+                self.hs.hostname, event
+            )
+        )
+
+        # The event should not have an entry in the `event_search` table
+        f = self.get_failure(
+            store.db_pool.simple_select_one_onecol(
+                "event_search",
+                {"room_id": room_id, "event_id": event.event_id},
+                "event_id",
+            ),
+            StoreError,
+        )
+        self.assertEqual(f.value.code, 404)
+
+    @skip_unless(not USE_POSTGRES_FOR_TESTS, "requires sqlite")
+    def test_sqlite_non_string_deletion_background_update(self):
+        """Test the background update to delete bad rows from `event_search`."""
+        store = self.hs.get_datastores().main
+
+        # Populate `event_search` with dummy data
+        self.get_success(
+            store.db_pool.simple_insert_many(
+                "event_search",
+                keys=["event_id", "room_id", "key", "value"],
+                values=[
+                    ("event1", "room_id", "content.body", "hi"),
+                    ("event2", "room_id", "content.body", "2"),
+                    ("event3", "room_id", "content.body", 3),
+                ],
+                desc="populate_event_search",
+            )
+        )
+
+        # Run the background update
+        store.db_pool.updates._all_done = False
+        self.get_success(
+            store.db_pool.simple_insert(
+                "background_updates",
+                {
+                    "update_name": "event_search_sqlite_delete_non_strings",
+                    "progress_json": "{}",
+                },
+            )
+        )
+        self.wait_for_background_updates()
+
+        # The non-string `value`s ought to be gone now.
+        values = self.get_success(
+            store.db_pool.simple_select_onecol(
+                "event_search",
+                {"room_id": "room_id"},
+                "value",
+            ),
+        )
+        self.assertCountEqual(values, ["hi", "2"])
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5cfdfe9b85..b8f09a8ee0 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -35,7 +35,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
 
         # We can't test the RoomMemberStore on its own without the other event
         # storage logic
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         self.u_alice = self.register_user("alice", "pass")
         self.t_alice = self.login("alice", "pass")
@@ -55,7 +55,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.assertEquals([self.room], [m.room_id for m in rooms_for_user])
+        self.assertEqual([self.room], [m.room_id for m in rooms_for_user])
 
     def test_count_known_servers(self):
         """
@@ -212,7 +212,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
 
 class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastore()
+        self.store = homeserver.get_datastores().main
         self.room_creator = homeserver.get_room_creation_handler()
 
     def test_can_rerun_update(self):
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 28c767ecfd..f88f1c55fc 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
 
 class StateStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.state_datastore = self.storage.state.stores.state
         self.event_builder_factory = hs.get_event_builder_factory()
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index ce782c7e1d..6a1cf33054 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -115,7 +115,7 @@ class PaginationTestCase(HomeserverTestCase):
         )
 
         events, next_key = self.get_success(
-            self.hs.get_datastore().paginate_room_events(
+            self.hs.get_datastores().main.paginate_room_events(
                 room_id=self.room_id,
                 from_key=from_token.room_key,
                 to_key=None,
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index bea9091d30..e05daa285e 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -20,7 +20,7 @@ from tests.unittest import HomeserverTestCase
 
 class TransactionStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastore()
+        self.store = homeserver.get_datastores().main
 
     def test_get_set_transactions(self):
         """Tests that we can successfully get a non-existent entry for
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 48f1e9d841..7f1964eb6a 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -149,7 +149,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
         self.user_dir_helper = GetUserDirectoryTables(self.store)
 
     def _purge_and_rebuild_user_dir(self) -> None:
@@ -415,7 +415,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
 
 class UserDirectoryStoreTestCase(HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastores().main
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index f8341041ee..31546ea52b 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -48,7 +48,7 @@ class DistributorTestCase(unittest.TestCase):
             observers[0].assert_called_once_with("Go")
             observers[1].assert_called_once_with("Go")
 
-            self.assertEquals(mock_logger.warning.call_count, 1)
+            self.assertEqual(mock_logger.warning.call_count, 1)
             self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
 
     def test_signal_prereg(self):
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 2b9804aba0..c39816de85 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -52,11 +52,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             )
         )[0]["room_id"]
 
-        self.store = self.homeserver.get_datastore()
+        self.store = self.homeserver.get_datastores().main
 
         # Figure out what the most recent event is
         most_recent = self.get_success(
-            self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id)
+            self.homeserver.get_datastores().main.get_latest_event_ids_in_room(
+                self.room_id
+            )
         )[0]
 
         join_event = make_event_from_dict(
@@ -185,7 +187,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
-        store = self.homeserver.get_datastore()
+        store = self.homeserver.get_datastores().main
         store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
 
         # Manually inject a fake device list update. We need this update to include at
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 80ab40e255..46bd3075de 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -52,7 +52,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         return config
 
     def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastore()
+        self.store = homeserver.get_datastores().main
 
     def test_simple_deny_mau(self):
         # Create and sync so that the MAU counts get updated
diff --git a/tests/test_state.py b/tests/test_state.py
index 76e0e8ca7f..e4baa69137 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional
+from typing import Collection, Dict, List, Optional
 from unittest.mock import Mock
 
 from twisted.internet import defer
@@ -70,7 +70,7 @@ def create_event(
     return event
 
 
-class StateGroupStore:
+class _DummyStore:
     def __init__(self):
         self._event_to_state_group = {}
         self._group_to_state = {}
@@ -105,6 +105,11 @@ class StateGroupStore:
             if e_id in self._event_id_to_event
         }
 
+    async def get_partial_state_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, bool]:
+        return {e: False for e in event_ids}
+
     async def get_state_group_delta(self, name):
         return None, None
 
@@ -157,12 +162,12 @@ class Graph:
 
 class StateTestCase(unittest.TestCase):
     def setUp(self):
-        self.store = StateGroupStore()
-        storage = Mock(main=self.store, state=self.store)
+        self.dummy_store = _DummyStore()
+        storage = Mock(main=self.dummy_store, state=self.dummy_store)
         hs = Mock(
             spec_set=[
                 "config",
-                "get_datastore",
+                "get_datastores",
                 "get_storage",
                 "get_auth",
                 "get_state_handler",
@@ -173,7 +178,7 @@ class StateTestCase(unittest.TestCase):
             ]
         )
         hs.config = default_config("tesths", True)
-        hs.get_datastore.return_value = self.store
+        hs.get_datastores.return_value = Mock(main=self.dummy_store)
         hs.get_state_handler.return_value = None
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
@@ -198,7 +203,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store: dict[str, EventContext] = {}
 
@@ -206,7 +211,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         ctx_c = context_store["C"]
@@ -242,7 +247,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
         )
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store = {}
 
@@ -250,7 +255,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         # C ends up winning the resolution between B and C
@@ -300,7 +305,7 @@ class StateTestCase(unittest.TestCase):
             edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
         )
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store = {}
 
@@ -308,7 +313,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         # C ends up winning the resolution between C and D because bans win over other
@@ -375,7 +380,7 @@ class StateTestCase(unittest.TestCase):
         self._add_depths(nodes, edges)
         graph = Graph(nodes, edges)
 
-        self.store.register_events(graph.walk())
+        self.dummy_store.register_events(graph.walk())
 
         context_store = {}
 
@@ -383,7 +388,7 @@ class StateTestCase(unittest.TestCase):
             context = yield defer.ensureDeferred(
                 self.state.compute_event_context(event)
             )
-            self.store.register_event_context(event, context)
+            self.dummy_store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         # B ends up winning the resolution between B and C because power levels
@@ -476,7 +481,7 @@ class StateTestCase(unittest.TestCase):
         ]
 
         group_name = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id,
                 event.room_id,
                 None,
@@ -484,7 +489,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id, group_name)
+        self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
@@ -510,7 +515,7 @@ class StateTestCase(unittest.TestCase):
         ]
 
         group_name = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id,
                 event.room_id,
                 None,
@@ -518,7 +523,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id, group_name)
+        self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
 
         context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
@@ -554,8 +559,8 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
-        self.store.register_events(old_state_1)
-        self.store.register_events(old_state_2)
+        self.dummy_store.register_events(old_state_1)
+        self.dummy_store.register_events(old_state_2)
 
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -594,10 +599,10 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
-        store = StateGroupStore()
+        store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events
 
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -649,10 +654,10 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=2),
         ]
 
-        store = StateGroupStore()
+        store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events
 
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
@@ -695,7 +700,7 @@ class StateTestCase(unittest.TestCase):
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
         sg1 = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id_1,
                 event.room_id,
                 None,
@@ -703,10 +708,10 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state_1},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id_1, sg1)
+        self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1)
 
         sg2 = yield defer.ensureDeferred(
-            self.store.store_state_group(
+            self.dummy_store.store_state_group(
                 prev_event_id_2,
                 event.room_id,
                 None,
@@ -714,7 +719,7 @@ class StateTestCase(unittest.TestCase):
                 {(e.type, e.state_key): e.event_id for e in old_state_2},
             )
         )
-        self.store.register_event_id_state_group(prev_event_id_2, sg2)
+        self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2)
 
         result = yield defer.ensureDeferred(self.state.compute_event_context(event))
         return result
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 67dcf567cd..37fada5c53 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -54,7 +54,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
         channel = self.make_request(b"POST", self.url, request_data)
 
-        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
 
         self.assertTrue(channel.json_body is not None)
         self.assertIsInstance(channel.json_body["session"], str)
@@ -99,7 +99,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
 
         # We don't bother checking that the response is correct - we'll leave that to
         # other tests. We just want to make sure we're on the right path.
-        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.result["code"], b"401", channel.result)
 
         # Finish the UI auth for terms
         request_data = json.dumps(
@@ -117,7 +117,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
         # We're interested in getting a response that looks like a successful
         # registration, not so much that the details are exactly what we want.
 
-        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
 
         self.assertTrue(channel.json_body is not None)
         self.assertIsInstance(channel.json_body["user_id"], str)
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index f2ef1c6051..d04bcae0fa 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -25,7 +25,7 @@ class MockClockTestCase(unittest.TestCase):
 
         self.clock.advance_time(20)
 
-        self.assertEquals(20, self.clock.time() - start_time)
+        self.assertEqual(20, self.clock.time() - start_time)
 
     def test_later(self):
         invoked = [0, 0]
diff --git a/tests/test_types.py b/tests/test_types.py
index 0d0c00d97a..80888a744d 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -22,9 +22,9 @@ class UserIDTestCase(unittest.HomeserverTestCase):
     def test_parse(self):
         user = UserID.from_string("@1234abcd:test")
 
-        self.assertEquals("1234abcd", user.localpart)
-        self.assertEquals("test", user.domain)
-        self.assertEquals(True, self.hs.is_mine(user))
+        self.assertEqual("1234abcd", user.localpart)
+        self.assertEqual("test", user.domain)
+        self.assertEqual(True, self.hs.is_mine(user))
 
     def test_pase_empty(self):
         with self.assertRaises(SynapseError):
@@ -33,7 +33,7 @@ class UserIDTestCase(unittest.HomeserverTestCase):
     def test_build(self):
         user = UserID("5678efgh", "my.domain")
 
-        self.assertEquals(user.to_string(), "@5678efgh:my.domain")
+        self.assertEqual(user.to_string(), "@5678efgh:my.domain")
 
     def test_compare(self):
         userA = UserID.from_string("@userA:my.domain")
@@ -48,14 +48,14 @@ class RoomAliasTestCase(unittest.HomeserverTestCase):
     def test_parse(self):
         room = RoomAlias.from_string("#channel:test")
 
-        self.assertEquals("channel", room.localpart)
-        self.assertEquals("test", room.domain)
-        self.assertEquals(True, self.hs.is_mine(room))
+        self.assertEqual("channel", room.localpart)
+        self.assertEqual("test", room.domain)
+        self.assertEqual(True, self.hs.is_mine(room))
 
     def test_build(self):
         room = RoomAlias("channel", "my.domain")
 
-        self.assertEquals(room.to_string(), "#channel:my.domain")
+        self.assertEqual(room.to_string(), "#channel:my.domain")
 
     def test_validate(self):
         id_string = "#test:domain,test"
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index e9ec9e085b..c654e36ee4 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -85,7 +85,9 @@ async def create_event(
     **kwargs,
 ) -> Tuple[EventBase, EventContext]:
     if room_version is None:
-        room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
+        room_version = await hs.get_datastores().main.get_room_version_id(
+            kwargs["room_id"]
+        )
 
     builder = hs.get_event_builder_factory().for_room_version(
         KNOWN_ROOM_VERSIONS[room_version], kwargs
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index e0b08d67d4..219b5660b1 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -93,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         events_to_filter.append(evt)
 
         # the erasey user gets erased
-        self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs"))
+        self.get_success(
+            self.hs.get_datastores().main.mark_user_erased("@erased:local_hs")
+        )
 
         # ... and the filtering happens.
         filtered = self.get_success(
diff --git a/tests/unittest.py b/tests/unittest.py
index 7983c1e8b8..326895f4c9 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -152,12 +152,12 @@ class TestCase(unittest.TestCase):
 
     def assertObjectHasAttributes(self, attrs, obj):
         """Asserts that the given object has each of the attributes given, and
-        that the value of each matches according to assertEquals."""
+        that the value of each matches according to assertEqual."""
         for key in attrs.keys():
             if not hasattr(obj, key):
                 raise AssertionError("Expected obj to have a '.%s'" % key)
             try:
-                self.assertEquals(attrs[key], getattr(obj, key))
+                self.assertEqual(attrs[key], getattr(obj, key))
             except AssertionError as e:
                 raise (type(e))(f"Assert error for '.{key}':") from e
 
@@ -169,7 +169,7 @@ class TestCase(unittest.TestCase):
             actual (dict): The test result. Extra keys will not be checked.
         """
         for key in required:
-            self.assertEquals(
+            self.assertEqual(
                 required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
             )
 
@@ -280,7 +280,7 @@ class HomeserverTestCase(TestCase):
 
                 # We need a valid token ID to satisfy foreign key constraints.
                 token_id = self.get_success(
-                    self.hs.get_datastore().add_access_token_to_user(
+                    self.hs.get_datastores().main.add_access_token_to_user(
                         self.helper.auth_user_id,
                         "some_fake_token",
                         None,
@@ -337,7 +337,7 @@ class HomeserverTestCase(TestCase):
 
     def wait_for_background_updates(self) -> None:
         """Block until all background database updates have completed."""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         while not self.get_success(
             store.db_pool.updates.has_completed_background_updates()
         ):
@@ -504,7 +504,7 @@ class HomeserverTestCase(TestCase):
                 self.get_success(stor.db_pool.updates.run_background_updates(False))
 
         hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
-        stor = hs.get_datastore()
+        stor = hs.get_datastores().main
 
         # Run the database background updates, when running against "master".
         if hs.__class__.__name__ == "TestHomeServer":
@@ -722,14 +722,16 @@ class HomeserverTestCase(TestCase):
         Add the given event as an extremity to the room.
         """
         self.get_success(
-            self.hs.get_datastore().db_pool.simple_insert(
+            self.hs.get_datastores().main.db_pool.simple_insert(
                 table="event_forward_extremities",
                 values={"room_id": room_id, "event_id": event_id},
                 desc="test_add_extremity",
             )
         )
 
-        self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,))
+        self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate(
+            (room_id,)
+        )
 
     def attempt_wrong_password_login(self, username, password):
         """Attempts to login as the user with the given password, asserting
@@ -775,7 +777,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
         verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
 
         self.get_success(
-            hs.get_datastore().store_server_verify_keys(
+            hs.get_datastores().main.store_server_verify_keys(
                 from_server=self.OTHER_SERVER_NAME,
                 ts_added_ms=clock.time_msec(),
                 verify_keys=[
diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py
index c613ce3f10..02b99b466a 100644
--- a/tests/util/caches/test_deferred_cache.py
+++ b/tests/util/caches/test_deferred_cache.py
@@ -31,7 +31,7 @@ class DeferredCacheTestCase(TestCase):
         cache = DeferredCache("test")
         cache.prefill("foo", 123)
 
-        self.assertEquals(self.successResultOf(cache.get("foo")), 123)
+        self.assertEqual(self.successResultOf(cache.get("foo")), 123)
 
     def test_hit_deferred(self):
         cache = DeferredCache("test")
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index ced3efd93f..19741ffcda 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -434,8 +434,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
 
         a = A()
 
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals((yield a.func("bar")), "bar")
+        self.assertEqual((yield a.func("foo")), "foo")
+        self.assertEqual((yield a.func("bar")), "bar")
 
     @defer.inlineCallbacks
     def test_hit(self):
@@ -450,10 +450,10 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         a = A()
         yield a.func("foo")
 
-        self.assertEquals(callcount[0], 1)
+        self.assertEqual(callcount[0], 1)
 
-        self.assertEquals((yield a.func("foo")), "foo")
-        self.assertEquals(callcount[0], 1)
+        self.assertEqual((yield a.func("foo")), "foo")
+        self.assertEqual(callcount[0], 1)
 
     @defer.inlineCallbacks
     def test_invalidate(self):
@@ -468,13 +468,13 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         a = A()
         yield a.func("foo")
 
-        self.assertEquals(callcount[0], 1)
+        self.assertEqual(callcount[0], 1)
 
         a.func.invalidate(("foo",))
 
         yield a.func("foo")
 
-        self.assertEquals(callcount[0], 2)
+        self.assertEqual(callcount[0], 2)
 
     def test_invalidate_missing(self):
         class A:
@@ -499,7 +499,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         for k in range(0, 12):
             yield a.func(k)
 
-        self.assertEquals(callcount[0], 12)
+        self.assertEqual(callcount[0], 12)
 
         # There must have been at least 2 evictions, meaning if we calculate
         # all 12 values again, we must get called at least 2 more times
@@ -525,8 +525,8 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
 
         a.func.prefill(("foo",), 456)
 
-        self.assertEquals(a.func("foo").result, 456)
-        self.assertEquals(callcount[0], 0)
+        self.assertEqual(a.func("foo").result, 456)
+        self.assertEqual(callcount[0], 0)
 
     @defer.inlineCallbacks
     def test_invalidate_context(self):
@@ -547,19 +547,19 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         a = A()
         yield a.func2("foo")
 
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
+        self.assertEqual(callcount[0], 1)
+        self.assertEqual(callcount2[0], 1)
 
         a.func.invalidate(("foo",))
         yield a.func("foo")
 
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 1)
+        self.assertEqual(callcount[0], 2)
+        self.assertEqual(callcount2[0], 1)
 
         yield a.func2("foo")
 
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
+        self.assertEqual(callcount[0], 2)
+        self.assertEqual(callcount2[0], 2)
 
     @defer.inlineCallbacks
     def test_eviction_context(self):
@@ -581,22 +581,22 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         yield a.func2("foo")
         yield a.func2("foo2")
 
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
+        self.assertEqual(callcount[0], 2)
+        self.assertEqual(callcount2[0], 2)
 
         yield a.func2("foo")
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
+        self.assertEqual(callcount[0], 2)
+        self.assertEqual(callcount2[0], 2)
 
         yield a.func("foo3")
 
-        self.assertEquals(callcount[0], 3)
-        self.assertEquals(callcount2[0], 2)
+        self.assertEqual(callcount[0], 3)
+        self.assertEqual(callcount2[0], 2)
 
         yield a.func2("foo")
 
-        self.assertEquals(callcount[0], 4)
-        self.assertEquals(callcount2[0], 3)
+        self.assertEqual(callcount[0], 4)
+        self.assertEqual(callcount2[0], 3)
 
     @defer.inlineCallbacks
     def test_double_get(self):
@@ -619,30 +619,30 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
 
         yield a.func2("foo")
 
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 1)
+        self.assertEqual(callcount[0], 1)
+        self.assertEqual(callcount2[0], 1)
 
         a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1)
+        self.assertEqual(a.func2.cache.cache.del_multi.call_count, 1)
 
         yield a.func2("foo")
         a.func2.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2)
+        self.assertEqual(a.func2.cache.cache.del_multi.call_count, 2)
 
-        self.assertEquals(callcount[0], 1)
-        self.assertEquals(callcount2[0], 2)
+        self.assertEqual(callcount[0], 1)
+        self.assertEqual(callcount2[0], 2)
 
         a.func.invalidate(("foo",))
-        self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3)
+        self.assertEqual(a.func2.cache.cache.del_multi.call_count, 3)
         yield a.func("foo")
 
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 2)
+        self.assertEqual(callcount[0], 2)
+        self.assertEqual(callcount2[0], 2)
 
         yield a.func2("foo")
 
-        self.assertEquals(callcount[0], 2)
-        self.assertEquals(callcount2[0], 3)
+        self.assertEqual(callcount[0], 2)
+        self.assertEqual(callcount2[0], 3)
 
 
 class CachedListDescriptorTestCase(unittest.TestCase):
@@ -673,14 +673,14 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             self.assertEqual(current_context(), SENTINEL_CONTEXT)
             r = yield d1
             self.assertEqual(current_context(), c1)
-            obj.mock.assert_called_once_with((10, 20), 2)
+            obj.mock.assert_called_once_with({10, 20}, 2)
             self.assertEqual(r, {10: "fish", 20: "chips"})
             obj.mock.reset_mock()
 
             # a call with different params should call the mock again
             obj.mock.return_value = {30: "peas"}
             r = yield obj.list_fn([20, 30], 2)
-            obj.mock.assert_called_once_with((30,), 2)
+            obj.mock.assert_called_once_with({30}, 2)
             self.assertEqual(r, {20: "chips", 30: "peas"})
             obj.mock.reset_mock()
 
@@ -701,7 +701,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             obj.mock.return_value = {40: "gravy"}
             iterable = (x for x in [10, 40, 40])
             r = yield obj.list_fn(iterable, 2)
-            obj.mock.assert_called_once_with((40,), 2)
+            obj.mock.assert_called_once_with({40}, 2)
             self.assertEqual(r, {10: "fish", 40: "gravy"})
 
     def test_concurrent_lookups(self):
@@ -729,7 +729,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         d3 = obj.list_fn([10])
 
         # the mock should have been called exactly once
-        obj.mock.assert_called_once_with((10,))
+        obj.mock.assert_called_once_with({10})
         obj.mock.reset_mock()
 
         # ... and none of the calls should yet be complete
@@ -771,7 +771,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         # cache miss
         obj.mock.return_value = {10: "fish", 20: "chips"}
         r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
-        obj.mock.assert_called_once_with((10, 20), 2)
+        obj.mock.assert_called_once_with({10, 20}, 2)
         self.assertEqual(r1, {10: "fish", 20: "chips"})
         obj.mock.reset_mock()
 
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index ab89cab812..362014f4cb 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -11,9 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import traceback
+
 from twisted.internet import defer
-from twisted.internet.defer import CancelledError, Deferred
+from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
 from twisted.internet.task import Clock
+from twisted.python.failure import Failure
 
 from synapse.logging.context import (
     SENTINEL_CONTEXT,
@@ -21,7 +24,12 @@ from synapse.logging.context import (
     PreserveLoggingContext,
     current_context,
 )
-from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
+from synapse.util.async_helpers import (
+    ObservableDeferred,
+    concurrently_execute,
+    stop_cancellation,
+    timeout_deferred,
+)
 
 from tests.unittest import TestCase
 
@@ -171,3 +179,151 @@ class TimeoutDeferredTest(TestCase):
             )
             self.failureResultOf(timing_out_d, defer.TimeoutError)
             self.assertIs(current_context(), context_one)
+
+
+class _TestException(Exception):
+    pass
+
+
+class ConcurrentlyExecuteTest(TestCase):
+    def test_limits_runners(self):
+        """If we have more tasks than runners, we should get the limit of runners"""
+        started = 0
+        waiters = []
+        processed = []
+
+        async def callback(v):
+            # when we first enter, bump the start count
+            nonlocal started
+            started += 1
+
+            # record the fact we got an item
+            processed.append(v)
+
+            # wait for the goahead before returning
+            d2 = Deferred()
+            waiters.append(d2)
+            await d2
+
+        # set it going
+        d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
+
+        # check we got exactly 3 processes
+        self.assertEqual(started, 3)
+        self.assertEqual(len(waiters), 3)
+
+        # let one finish
+        waiters.pop().callback(0)
+
+        # ... which should start another
+        self.assertEqual(started, 4)
+        self.assertEqual(len(waiters), 3)
+
+        # we still shouldn't be done
+        self.assertNoResult(d2)
+
+        # finish the job
+        while waiters:
+            waiters.pop().callback(0)
+
+        # check everything got done
+        self.assertEqual(started, 5)
+        self.assertCountEqual(processed, [1, 2, 3, 4, 5])
+        self.successResultOf(d2)
+
+    def test_preserves_stacktraces(self):
+        """Test that the stacktrace from an exception thrown in the callback is preserved"""
+        d1 = Deferred()
+
+        async def callback(v):
+            # alas, this doesn't work at all without an await here
+            await d1
+            raise _TestException("bah")
+
+        async def caller():
+            try:
+                await concurrently_execute(callback, [1], 2)
+            except _TestException as e:
+                tb = traceback.extract_tb(e.__traceback__)
+                # we expect to see "caller", "concurrently_execute" and "callback".
+                self.assertEqual(tb[0].name, "caller")
+                self.assertEqual(tb[1].name, "concurrently_execute")
+                self.assertEqual(tb[-1].name, "callback")
+            else:
+                self.fail("No exception thrown")
+
+        d2 = ensureDeferred(caller())
+        d1.callback(0)
+        self.successResultOf(d2)
+
+    def test_preserves_stacktraces_on_preformed_failure(self):
+        """Test that the stacktrace on a Failure returned by the callback is preserved"""
+        d1 = Deferred()
+        f = Failure(_TestException("bah"))
+
+        async def callback(v):
+            # alas, this doesn't work at all without an await here
+            await d1
+            await defer.fail(f)
+
+        async def caller():
+            try:
+                await concurrently_execute(callback, [1], 2)
+            except _TestException as e:
+                tb = traceback.extract_tb(e.__traceback__)
+                # we expect to see "caller", "concurrently_execute", "callback",
+                # and some magic from inside ensureDeferred that happens when .fail
+                # is called.
+                self.assertEqual(tb[0].name, "caller")
+                self.assertEqual(tb[1].name, "concurrently_execute")
+                self.assertEqual(tb[-2].name, "callback")
+            else:
+                self.fail("No exception thrown")
+
+        d2 = ensureDeferred(caller())
+        d1.callback(0)
+        self.successResultOf(d2)
+
+
+class StopCancellationTests(TestCase):
+    """Tests for the `stop_cancellation` function."""
+
+    def test_succeed(self):
+        """Test that the new `Deferred` receives the result."""
+        deferred: "Deferred[str]" = Deferred()
+        wrapper_deferred = stop_cancellation(deferred)
+
+        # Success should propagate through.
+        deferred.callback("success")
+        self.assertTrue(wrapper_deferred.called)
+        self.assertEqual("success", self.successResultOf(wrapper_deferred))
+
+    def test_failure(self):
+        """Test that the new `Deferred` receives the `Failure`."""
+        deferred: "Deferred[str]" = Deferred()
+        wrapper_deferred = stop_cancellation(deferred)
+
+        # Failure should propagate through.
+        deferred.errback(ValueError("abc"))
+        self.assertTrue(wrapper_deferred.called)
+        self.failureResultOf(wrapper_deferred, ValueError)
+        self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+    def test_cancellation(self):
+        """Test that cancellation of the new `Deferred` leaves the original running."""
+        deferred: "Deferred[str]" = Deferred()
+        wrapper_deferred = stop_cancellation(deferred)
+
+        # Cancel the new `Deferred`.
+        wrapper_deferred.cancel()
+        self.assertTrue(wrapper_deferred.called)
+        self.failureResultOf(wrapper_deferred, CancelledError)
+        self.assertFalse(
+            deferred.called, "Original `Deferred` was unexpectedly cancelled."
+        )
+
+        # Now make the inner `Deferred` fail.
+        # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+        # in logs.
+        deferred.errback(ValueError("abc"))
+        self.assertIsNone(deferred.result, "`Failure` was not consumed")
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
new file mode 100644
index 0000000000..3c07252252
--- /dev/null
+++ b/tests/util/test_check_dependencies.py
@@ -0,0 +1,95 @@
+from contextlib import contextmanager
+from typing import Generator, Optional
+from unittest.mock import patch
+
+from synapse.util.check_dependencies import (
+    DependencyException,
+    check_requirements,
+    metadata,
+)
+
+from tests.unittest import TestCase
+
+
+class DummyDistribution(metadata.Distribution):
+    def __init__(self, version: str):
+        self._version = version
+
+    @property
+    def version(self):
+        return self._version
+
+    def locate_file(self, path):
+        raise NotImplementedError()
+
+    def read_text(self, filename):
+        raise NotImplementedError()
+
+
+old = DummyDistribution("0.1.2")
+new = DummyDistribution("1.2.3")
+
+# could probably use stdlib TestCase --- no need for twisted here
+
+
+class TestDependencyChecker(TestCase):
+    @contextmanager
+    def mock_installed_package(
+        self, distribution: Optional[DummyDistribution]
+    ) -> Generator[None, None, None]:
+        """Pretend that looking up any distribution yields the given `distribution`."""
+
+        def mock_distribution(name: str):
+            if distribution is None:
+                raise metadata.PackageNotFoundError
+            else:
+                return distribution
+
+        with patch(
+            "synapse.util.check_dependencies.metadata.distribution",
+            mock_distribution,
+        ):
+            yield
+
+    def test_mandatory_dependency(self) -> None:
+        """Complain if a required package is missing or old."""
+        with patch(
+            "synapse.util.check_dependencies.metadata.requires",
+            return_value=["dummypkg >= 1"],
+        ):
+            with self.mock_installed_package(None):
+                self.assertRaises(DependencyException, check_requirements)
+            with self.mock_installed_package(old):
+                self.assertRaises(DependencyException, check_requirements)
+            with self.mock_installed_package(new):
+                # should not raise
+                check_requirements()
+
+    def test_generic_check_of_optional_dependency(self) -> None:
+        """Complain if an optional package is old."""
+        with patch(
+            "synapse.util.check_dependencies.metadata.requires",
+            return_value=["dummypkg >= 1; extra == 'cool-extra'"],
+        ):
+            with self.mock_installed_package(None):
+                # should not raise
+                check_requirements()
+            with self.mock_installed_package(old):
+                self.assertRaises(DependencyException, check_requirements)
+            with self.mock_installed_package(new):
+                # should not raise
+                check_requirements()
+
+    def test_check_for_extra_dependencies(self) -> None:
+        """Complain if a package required for an extra is missing or old."""
+        with patch(
+            "synapse.util.check_dependencies.metadata.requires",
+            return_value=["dummypkg >= 1; extra == 'cool-extra'"],
+        ), patch("synapse.util.check_dependencies.EXTRAS", {"cool-extra"}):
+            with self.mock_installed_package(None):
+                self.assertRaises(DependencyException, check_requirements, "cool-extra")
+            with self.mock_installed_package(old):
+                self.assertRaises(DependencyException, check_requirements, "cool-extra")
+            with self.mock_installed_package(new):
+                # should not raise
+                check_requirements()
diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py
index e6e13ba06c..7f60aae5ba 100644
--- a/tests/util/test_expiring_cache.py
+++ b/tests/util/test_expiring_cache.py
@@ -26,8 +26,8 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
         cache = ExpiringCache("test", clock, max_len=1)
 
         cache["key"] = "value"
-        self.assertEquals(cache.get("key"), "value")
-        self.assertEquals(cache["key"], "value")
+        self.assertEqual(cache.get("key"), "value")
+        self.assertEqual(cache["key"], "value")
 
     def test_eviction(self):
         clock = MockClock()
@@ -35,13 +35,13 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
 
         cache["key"] = "value"
         cache["key2"] = "value2"
-        self.assertEquals(cache.get("key"), "value")
-        self.assertEquals(cache.get("key2"), "value2")
+        self.assertEqual(cache.get("key"), "value")
+        self.assertEqual(cache.get("key2"), "value2")
 
         cache["key3"] = "value3"
-        self.assertEquals(cache.get("key"), None)
-        self.assertEquals(cache.get("key2"), "value2")
-        self.assertEquals(cache.get("key3"), "value3")
+        self.assertEqual(cache.get("key"), None)
+        self.assertEqual(cache.get("key2"), "value2")
+        self.assertEqual(cache.get("key3"), "value3")
 
     def test_iterable_eviction(self):
         clock = MockClock()
@@ -51,15 +51,15 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
         cache["key2"] = [2, 3]
         cache["key3"] = [4, 5]
 
-        self.assertEquals(cache.get("key"), [1])
-        self.assertEquals(cache.get("key2"), [2, 3])
-        self.assertEquals(cache.get("key3"), [4, 5])
+        self.assertEqual(cache.get("key"), [1])
+        self.assertEqual(cache.get("key2"), [2, 3])
+        self.assertEqual(cache.get("key3"), [4, 5])
 
         cache["key4"] = [6, 7]
-        self.assertEquals(cache.get("key"), None)
-        self.assertEquals(cache.get("key2"), None)
-        self.assertEquals(cache.get("key3"), [4, 5])
-        self.assertEquals(cache.get("key4"), [6, 7])
+        self.assertEqual(cache.get("key"), None)
+        self.assertEqual(cache.get("key2"), None)
+        self.assertEqual(cache.get("key3"), [4, 5])
+        self.assertEqual(cache.get("key4"), [6, 7])
 
     def test_time_eviction(self):
         clock = MockClock()
@@ -69,13 +69,13 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
         clock.advance_time(0.5)
         cache["key2"] = 2
 
-        self.assertEquals(cache.get("key"), 1)
-        self.assertEquals(cache.get("key2"), 2)
+        self.assertEqual(cache.get("key"), 1)
+        self.assertEqual(cache.get("key2"), 2)
 
         clock.advance_time(0.9)
-        self.assertEquals(cache.get("key"), None)
-        self.assertEquals(cache.get("key2"), 2)
+        self.assertEqual(cache.get("key"), None)
+        self.assertEqual(cache.get("key2"), 2)
 
         clock.advance_time(1)
-        self.assertEquals(cache.get("key"), None)
-        self.assertEquals(cache.get("key2"), None)
+        self.assertEqual(cache.get("key"), None)
+        self.assertEqual(cache.get("key2"), None)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 621b0f9fcd..2ad321e184 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -17,7 +17,7 @@ from .. import unittest
 
 class LoggingContextTestCase(unittest.TestCase):
     def _check_test_key(self, value):
-        self.assertEquals(current_context().name, value)
+        self.assertEqual(current_context().name, value)
 
     def test_with_context(self):
         with LoggingContext("test"):
diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index 291644eb7d..321fc1776f 100644
--- a/tests/util/test_lrucache.py
+++ b/tests/util/test_lrucache.py
@@ -27,37 +27,37 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
     def test_get_set(self):
         cache = LruCache(1)
         cache["key"] = "value"
-        self.assertEquals(cache.get("key"), "value")
-        self.assertEquals(cache["key"], "value")
+        self.assertEqual(cache.get("key"), "value")
+        self.assertEqual(cache["key"], "value")
 
     def test_eviction(self):
         cache = LruCache(2)
         cache[1] = 1
         cache[2] = 2
 
-        self.assertEquals(cache.get(1), 1)
-        self.assertEquals(cache.get(2), 2)
+        self.assertEqual(cache.get(1), 1)
+        self.assertEqual(cache.get(2), 2)
 
         cache[3] = 3
 
-        self.assertEquals(cache.get(1), None)
-        self.assertEquals(cache.get(2), 2)
-        self.assertEquals(cache.get(3), 3)
+        self.assertEqual(cache.get(1), None)
+        self.assertEqual(cache.get(2), 2)
+        self.assertEqual(cache.get(3), 3)
 
     def test_setdefault(self):
         cache = LruCache(1)
-        self.assertEquals(cache.setdefault("key", 1), 1)
-        self.assertEquals(cache.get("key"), 1)
-        self.assertEquals(cache.setdefault("key", 2), 1)
-        self.assertEquals(cache.get("key"), 1)
+        self.assertEqual(cache.setdefault("key", 1), 1)
+        self.assertEqual(cache.get("key"), 1)
+        self.assertEqual(cache.setdefault("key", 2), 1)
+        self.assertEqual(cache.get("key"), 1)
         cache["key"] = 2  # Make sure overriding works.
-        self.assertEquals(cache.get("key"), 2)
+        self.assertEqual(cache.get("key"), 2)
 
     def test_pop(self):
         cache = LruCache(1)
         cache["key"] = 1
-        self.assertEquals(cache.pop("key"), 1)
-        self.assertEquals(cache.pop("key"), None)
+        self.assertEqual(cache.pop("key"), 1)
+        self.assertEqual(cache.pop("key"), None)
 
     def test_del_multi(self):
         cache = LruCache(4, cache_type=TreeCache)
@@ -66,23 +66,23 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
         cache[("vehicles", "car")] = "vroom"
         cache[("vehicles", "train")] = "chuff"
 
-        self.assertEquals(len(cache), 4)
+        self.assertEqual(len(cache), 4)
 
-        self.assertEquals(cache.get(("animal", "cat")), "mew")
-        self.assertEquals(cache.get(("vehicles", "car")), "vroom")
+        self.assertEqual(cache.get(("animal", "cat")), "mew")
+        self.assertEqual(cache.get(("vehicles", "car")), "vroom")
         cache.del_multi(("animal",))
-        self.assertEquals(len(cache), 2)
-        self.assertEquals(cache.get(("animal", "cat")), None)
-        self.assertEquals(cache.get(("animal", "dog")), None)
-        self.assertEquals(cache.get(("vehicles", "car")), "vroom")
-        self.assertEquals(cache.get(("vehicles", "train")), "chuff")
+        self.assertEqual(len(cache), 2)
+        self.assertEqual(cache.get(("animal", "cat")), None)
+        self.assertEqual(cache.get(("animal", "dog")), None)
+        self.assertEqual(cache.get(("vehicles", "car")), "vroom")
+        self.assertEqual(cache.get(("vehicles", "train")), "chuff")
         # Man from del_multi say "Yes".
 
     def test_clear(self):
         cache = LruCache(1)
         cache["key"] = 1
         cache.clear()
-        self.assertEquals(len(cache), 0)
+        self.assertEqual(len(cache), 0)
 
     @override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
     def test_special_size(self):
@@ -105,10 +105,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         self.assertFalse(m.called)
 
         cache.set("key", "value2")
-        self.assertEquals(m.call_count, 1)
+        self.assertEqual(m.call_count, 1)
 
         cache.set("key", "value")
-        self.assertEquals(m.call_count, 1)
+        self.assertEqual(m.call_count, 1)
 
     def test_multi_get(self):
         m = Mock()
@@ -124,10 +124,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         self.assertFalse(m.called)
 
         cache.set("key", "value2")
-        self.assertEquals(m.call_count, 1)
+        self.assertEqual(m.call_count, 1)
 
         cache.set("key", "value")
-        self.assertEquals(m.call_count, 1)
+        self.assertEqual(m.call_count, 1)
 
     def test_set(self):
         m = Mock()
@@ -140,10 +140,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         self.assertFalse(m.called)
 
         cache.set("key", "value2")
-        self.assertEquals(m.call_count, 1)
+        self.assertEqual(m.call_count, 1)
 
         cache.set("key", "value")
-        self.assertEquals(m.call_count, 1)
+        self.assertEqual(m.call_count, 1)
 
     def test_pop(self):
         m = Mock()
@@ -153,13 +153,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         self.assertFalse(m.called)
 
         cache.pop("key")
-        self.assertEquals(m.call_count, 1)
+        self.assertEqual(m.call_count, 1)
 
         cache.set("key", "value")
-        self.assertEquals(m.call_count, 1)
+        self.assertEqual(m.call_count, 1)
 
         cache.pop("key")
-        self.assertEquals(m.call_count, 1)
+        self.assertEqual(m.call_count, 1)
 
     def test_del_multi(self):
         m1 = Mock()
@@ -173,17 +173,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         cache.set(("b", "1"), "value", callbacks=[m3])
         cache.set(("b", "2"), "value", callbacks=[m4])
 
-        self.assertEquals(m1.call_count, 0)
-        self.assertEquals(m2.call_count, 0)
-        self.assertEquals(m3.call_count, 0)
-        self.assertEquals(m4.call_count, 0)
+        self.assertEqual(m1.call_count, 0)
+        self.assertEqual(m2.call_count, 0)
+        self.assertEqual(m3.call_count, 0)
+        self.assertEqual(m4.call_count, 0)
 
         cache.del_multi(("a",))
 
-        self.assertEquals(m1.call_count, 1)
-        self.assertEquals(m2.call_count, 1)
-        self.assertEquals(m3.call_count, 0)
-        self.assertEquals(m4.call_count, 0)
+        self.assertEqual(m1.call_count, 1)
+        self.assertEqual(m2.call_count, 1)
+        self.assertEqual(m3.call_count, 0)
+        self.assertEqual(m4.call_count, 0)
 
     def test_clear(self):
         m1 = Mock()
@@ -193,13 +193,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         cache.set("key1", "value", callbacks=[m1])
         cache.set("key2", "value", callbacks=[m2])
 
-        self.assertEquals(m1.call_count, 0)
-        self.assertEquals(m2.call_count, 0)
+        self.assertEqual(m1.call_count, 0)
+        self.assertEqual(m2.call_count, 0)
 
         cache.clear()
 
-        self.assertEquals(m1.call_count, 1)
-        self.assertEquals(m2.call_count, 1)
+        self.assertEqual(m1.call_count, 1)
+        self.assertEqual(m2.call_count, 1)
 
     def test_eviction(self):
         m1 = Mock(name="m1")
@@ -210,33 +210,33 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
         cache.set("key1", "value", callbacks=[m1])
         cache.set("key2", "value", callbacks=[m2])
 
-        self.assertEquals(m1.call_count, 0)
-        self.assertEquals(m2.call_count, 0)
-        self.assertEquals(m3.call_count, 0)
+        self.assertEqual(m1.call_count, 0)
+        self.assertEqual(m2.call_count, 0)
+        self.assertEqual(m3.call_count, 0)
 
         cache.set("key3", "value", callbacks=[m3])
 
-        self.assertEquals(m1.call_count, 1)
-        self.assertEquals(m2.call_count, 0)
-        self.assertEquals(m3.call_count, 0)
+        self.assertEqual(m1.call_count, 1)
+        self.assertEqual(m2.call_count, 0)
+        self.assertEqual(m3.call_count, 0)
 
         cache.set("key3", "value")
 
-        self.assertEquals(m1.call_count, 1)
-        self.assertEquals(m2.call_count, 0)
-        self.assertEquals(m3.call_count, 0)
+        self.assertEqual(m1.call_count, 1)
+        self.assertEqual(m2.call_count, 0)
+        self.assertEqual(m3.call_count, 0)
 
         cache.get("key2")
 
-        self.assertEquals(m1.call_count, 1)
-        self.assertEquals(m2.call_count, 0)
-        self.assertEquals(m3.call_count, 0)
+        self.assertEqual(m1.call_count, 1)
+        self.assertEqual(m2.call_count, 0)
+        self.assertEqual(m3.call_count, 0)
 
         cache.set("key1", "value", callbacks=[m1])
 
-        self.assertEquals(m1.call_count, 1)
-        self.assertEquals(m2.call_count, 0)
-        self.assertEquals(m3.call_count, 1)
+        self.assertEqual(m1.call_count, 1)
+        self.assertEqual(m2.call_count, 0)
+        self.assertEqual(m3.call_count, 1)
 
 
 class LruCacheSizedTestCase(unittest.HomeserverTestCase):
@@ -247,20 +247,20 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
         cache["key3"] = [3]
         cache["key4"] = [4]
 
-        self.assertEquals(cache["key1"], [0])
-        self.assertEquals(cache["key2"], [1, 2])
-        self.assertEquals(cache["key3"], [3])
-        self.assertEquals(cache["key4"], [4])
-        self.assertEquals(len(cache), 5)
+        self.assertEqual(cache["key1"], [0])
+        self.assertEqual(cache["key2"], [1, 2])
+        self.assertEqual(cache["key3"], [3])
+        self.assertEqual(cache["key4"], [4])
+        self.assertEqual(len(cache), 5)
 
         cache["key5"] = [5, 6]
 
-        self.assertEquals(len(cache), 4)
-        self.assertEquals(cache.get("key1"), None)
-        self.assertEquals(cache.get("key2"), None)
-        self.assertEquals(cache["key3"], [3])
-        self.assertEquals(cache["key4"], [4])
-        self.assertEquals(cache["key5"], [5, 6])
+        self.assertEqual(len(cache), 4)
+        self.assertEqual(cache.get("key1"), None)
+        self.assertEqual(cache.get("key2"), None)
+        self.assertEqual(cache["key3"], [3])
+        self.assertEqual(cache["key4"], [4])
+        self.assertEqual(cache["key5"], [5, 6])
 
     def test_zero_size_drop_from_cache(self) -> None:
         """Test that `drop_from_cache` works correctly with 0-sized entries."""
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9e1bebdc83..26cb71c640 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -24,7 +24,7 @@ from tests.unittest import HomeserverTestCase
 class RetryLimiterTestCase(HomeserverTestCase):
     def test_new_destination(self):
         """A happy-path case with a new destination and a successful operation"""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
         limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         # advance the clock a bit before making the request
@@ -38,7 +38,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
 
     def test_limiter(self):
         """General test case which walks through the process of a failing request"""
-        store = self.hs.get_datastore()
+        store = self.hs.get_datastores().main
 
         limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index a10071c70f..0774625b85 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from twisted.internet import defer
+from twisted.internet.defer import Deferred
 
 from synapse.util.async_helpers import ReadWriteLock
 
@@ -83,3 +84,32 @@ class ReadWriteLockTestCase(unittest.TestCase):
         self.assertTrue(d.called)
         with d.result:
             pass
+
+    def test_lock_handoff_to_nonblocking_writer(self):
+        """Test a writer handing the lock to another writer that completes instantly."""
+        rwlock = ReadWriteLock()
+        key = "key"
+
+        unblock: "Deferred[None]" = Deferred()
+
+        async def blocking_write():
+            with await rwlock.write(key):
+                await unblock
+
+        async def nonblocking_write():
+            with await rwlock.write(key):
+                pass
+
+        d1 = defer.ensureDeferred(blocking_write())
+        d2 = defer.ensureDeferred(nonblocking_write())
+        self.assertFalse(d1.called)
+        self.assertFalse(d2.called)
+
+        # Unblock the first writer. The second writer will complete without blocking.
+        unblock.callback(None)
+        self.assertTrue(d1.called)
+        self.assertTrue(d2.called)
+
+        # The `ReadWriteLock` should operate as normal.
+        d3 = defer.ensureDeferred(nonblocking_write())
+        self.assertTrue(d3.called)
diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py
index 6066372053..567cb18468 100644
--- a/tests/util/test_treecache.py
+++ b/tests/util/test_treecache.py
@@ -23,61 +23,61 @@ class TreeCacheTestCase(unittest.TestCase):
         cache = TreeCache()
         cache[("a",)] = "A"
         cache[("b",)] = "B"
-        self.assertEquals(cache.get(("a",)), "A")
-        self.assertEquals(cache.get(("b",)), "B")
-        self.assertEquals(len(cache), 2)
+        self.assertEqual(cache.get(("a",)), "A")
+        self.assertEqual(cache.get(("b",)), "B")
+        self.assertEqual(len(cache), 2)
 
     def test_pop_onelevel(self):
         cache = TreeCache()
         cache[("a",)] = "A"
         cache[("b",)] = "B"
-        self.assertEquals(cache.pop(("a",)), "A")
-        self.assertEquals(cache.pop(("a",)), None)
-        self.assertEquals(cache.get(("b",)), "B")
-        self.assertEquals(len(cache), 1)
+        self.assertEqual(cache.pop(("a",)), "A")
+        self.assertEqual(cache.pop(("a",)), None)
+        self.assertEqual(cache.get(("b",)), "B")
+        self.assertEqual(len(cache), 1)
 
     def test_get_set_twolevel(self):
         cache = TreeCache()
         cache[("a", "a")] = "AA"
         cache[("a", "b")] = "AB"
         cache[("b", "a")] = "BA"
-        self.assertEquals(cache.get(("a", "a")), "AA")
-        self.assertEquals(cache.get(("a", "b")), "AB")
-        self.assertEquals(cache.get(("b", "a")), "BA")
-        self.assertEquals(len(cache), 3)
+        self.assertEqual(cache.get(("a", "a")), "AA")
+        self.assertEqual(cache.get(("a", "b")), "AB")
+        self.assertEqual(cache.get(("b", "a")), "BA")
+        self.assertEqual(len(cache), 3)
 
     def test_pop_twolevel(self):
         cache = TreeCache()
         cache[("a", "a")] = "AA"
         cache[("a", "b")] = "AB"
         cache[("b", "a")] = "BA"
-        self.assertEquals(cache.pop(("a", "a")), "AA")
-        self.assertEquals(cache.get(("a", "a")), None)
-        self.assertEquals(cache.get(("a", "b")), "AB")
-        self.assertEquals(cache.pop(("b", "a")), "BA")
-        self.assertEquals(cache.pop(("b", "a")), None)
-        self.assertEquals(len(cache), 1)
+        self.assertEqual(cache.pop(("a", "a")), "AA")
+        self.assertEqual(cache.get(("a", "a")), None)
+        self.assertEqual(cache.get(("a", "b")), "AB")
+        self.assertEqual(cache.pop(("b", "a")), "BA")
+        self.assertEqual(cache.pop(("b", "a")), None)
+        self.assertEqual(len(cache), 1)
 
     def test_pop_mixedlevel(self):
         cache = TreeCache()
         cache[("a", "a")] = "AA"
         cache[("a", "b")] = "AB"
         cache[("b", "a")] = "BA"
-        self.assertEquals(cache.get(("a", "a")), "AA")
+        self.assertEqual(cache.get(("a", "a")), "AA")
         popped = cache.pop(("a",))
-        self.assertEquals(cache.get(("a", "a")), None)
-        self.assertEquals(cache.get(("a", "b")), None)
-        self.assertEquals(cache.get(("b", "a")), "BA")
-        self.assertEquals(len(cache), 1)
+        self.assertEqual(cache.get(("a", "a")), None)
+        self.assertEqual(cache.get(("a", "b")), None)
+        self.assertEqual(cache.get(("b", "a")), "BA")
+        self.assertEqual(len(cache), 1)
 
-        self.assertEquals({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
+        self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
 
     def test_clear(self):
         cache = TreeCache()
         cache[("a",)] = "A"
         cache[("b",)] = "B"
         cache.clear()
-        self.assertEquals(len(cache), 0)
+        self.assertEqual(len(cache), 0)
 
     def test_contains(self):
         cache = TreeCache()
diff --git a/tests/utils.py b/tests/utils.py
index c06fc320f3..ef99c72e0b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -367,7 +367,7 @@ async def create_room(hs, room_id: str, creator_id: str):
     """Creates and persist a creation event for the given room"""
 
     persistence_store = hs.get_storage().persistence
-    store = hs.get_datastore()
+    store = hs.get_datastores().main
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()
 
diff --git a/tox.ini b/tox.ini
index 436ecf7552..04b972e2c5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -168,16 +168,6 @@ commands =
 extras = lint
 commands = isort -c --df {[base]lint_targets}
 
-[testenv:check-newsfragment]
-skip_install = true
-usedevelop = false
-deps = towncrier>=18.6.0rc1
-commands =
-   python -m towncrier.check --compare-with=origin/develop
-
-[testenv:check-sampleconfig]
-commands = {toxinidir}/scripts-dev/generate_sample_config --check
-
 [testenv:combine]
 skip_install = true
 usedevelop = false