-rw-r--r--  .github/workflows/tests.yml | 12
-rw-r--r--  CHANGES.md | 133
-rw-r--r--  contrib/docker/docker-compose.yml | 1
-rw-r--r--  debian/changelog | 24
-rw-r--r--  docker/Dockerfile-dhvirtualenv | 9
-rw-r--r--  docker/README.md | 4
-rw-r--r--  docs/SUMMARY.md | 1
-rw-r--r--  docs/admin_api/user_admin_api.md | 75
-rw-r--r--  docs/deprecation_policy.md | 4
-rw-r--r--  docs/reverse_proxy.md | 2
-rw-r--r--  docs/sample_config.yaml | 3
-rw-r--r--  docs/setup/installation.md | 2
-rw-r--r--  docs/sso_mapping_providers.md | 24
-rw-r--r--  docs/turn-howto.md | 67
-rw-r--r--  docs/upgrade.md | 11
-rw-r--r--  docs/usage/configuration/user_authentication/refresh_tokens.md | 139
-rw-r--r--  mypy.ini | 48
-rw-r--r--  pyproject.toml | 2
-rwxr-xr-x  scripts-dev/build_debian_packages | 1
-rwxr-xr-x  scripts-dev/check-newsfragment | 4
-rwxr-xr-x  setup.py | 8
-rw-r--r--  stubs/txredisapi.pyi | 9
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/auth.py | 135
-rw-r--r--  synapse/api/constants.py | 4
-rw-r--r--  synapse/api/filtering.py | 3
-rw-r--r--  synapse/app/homeserver.py | 9
-rw-r--r--  synapse/appservice/__init__.py | 101
-rw-r--r--  synapse/appservice/api.py | 48
-rw-r--r--  synapse/appservice/scheduler.py | 74
-rw-r--r--  synapse/config/api.py | 2
-rw-r--r--  synapse/config/appservice.py | 3
-rw-r--r--  synapse/config/experimental.py | 7
-rw-r--r--  synapse/config/key.py | 36
-rw-r--r--  synapse/config/metrics.py | 6
-rw-r--r--  synapse/config/modules.py | 2
-rw-r--r--  synapse/config/repository.py | 34
-rw-r--r--  synapse/config/room_directory.py | 3
-rw-r--r--  synapse/config/server.py | 2
-rw-r--r--  synapse/config/tls.py | 5
-rw-r--r--  synapse/events/utils.py | 13
-rw-r--r--  synapse/federation/federation_base.py | 12
-rw-r--r--  synapse/federation/federation_client.py | 38
-rw-r--r--  synapse/federation/federation_server.py | 12
-rw-r--r--  synapse/federation/send_queue.py | 47
-rw-r--r--  synapse/federation/transport/server/_base.py | 39
-rw-r--r--  synapse/handlers/appservice.py | 4
-rw-r--r--  synapse/handlers/auth.py | 4
-rw-r--r--  synapse/handlers/device.py | 10
-rw-r--r--  synapse/handlers/directory.py | 10
-rw-r--r--  synapse/handlers/e2e_keys.py | 12
-rw-r--r--  synapse/handlers/e2e_room_keys.py | 15
-rw-r--r--  synapse/handlers/events.py | 9
-rw-r--r--  synapse/handlers/federation.py | 19
-rw-r--r--  synapse/handlers/federation_event.py | 21
-rw-r--r--  synapse/handlers/initial_sync.py | 51
-rw-r--r--  synapse/handlers/message.py | 37
-rw-r--r--  synapse/handlers/pagination.py | 5
-rw-r--r--  synapse/handlers/presence.py | 2
-rw-r--r--  synapse/handlers/receipts.py | 6
-rw-r--r--  synapse/handlers/room.py | 2
-rw-r--r--  synapse/handlers/room_list.py | 22
-rw-r--r--  synapse/handlers/room_member.py | 9
-rw-r--r--  synapse/handlers/stats.py | 11
-rw-r--r--  synapse/handlers/sync.py | 74
-rw-r--r--  synapse/handlers/typing.py | 14
-rw-r--r--  synapse/handlers/user_directory.py | 18
-rw-r--r--  synapse/http/__init__.py | 6
-rw-r--r--  synapse/http/additional_resource.py | 12
-rw-r--r--  synapse/http/client.py | 17
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py | 7
-rw-r--r--  synapse/http/matrixfederationclient.py | 3
-rw-r--r--  synapse/http/server.py | 125
-rw-r--r--  synapse/http/servlet.py | 50
-rw-r--r--  synapse/http/site.py | 36
-rw-r--r--  synapse/logging/context.py | 150
-rw-r--r--  synapse/logging/opentracing.py | 91
-rw-r--r--  synapse/logging/scopecontextmanager.py | 2
-rw-r--r--  synapse/notifier.py | 25
-rw-r--r--  synapse/push/emailpusher.py | 18
-rw-r--r--  synapse/push/httppusher.py | 12
-rw-r--r--  synapse/push/mailer.py | 40
-rw-r--r--  synapse/push/push_rule_evaluator.py | 7
-rw-r--r--  synapse/push/push_tools.py | 7
-rw-r--r--  synapse/push/pusherpool.py | 5
-rw-r--r--  synapse/python_dependencies.py | 4
-rw-r--r--  synapse/replication/slave/storage/_base.py | 9
-rw-r--r--  synapse/replication/slave/storage/client_ips.py | 9
-rw-r--r--  synapse/replication/slave/storage/devices.py | 9
-rw-r--r--  synapse/replication/slave/storage/events.py | 18
-rw-r--r--  synapse/replication/slave/storage/filtering.py | 9
-rw-r--r--  synapse/replication/slave/storage/groups.py | 9
-rw-r--r--  synapse/replication/tcp/streams/_base.py | 129
-rw-r--r--  synapse/replication/tcp/streams/federation.py | 15
-rw-r--r--  synapse/rest/admin/__init__.py | 6
-rw-r--r--  synapse/rest/admin/background_updates.py | 16
-rw-r--r--  synapse/rest/admin/devices.py | 22
-rw-r--r--  synapse/rest/admin/event_reports.py | 2
-rw-r--r--  synapse/rest/admin/federation.py | 2
-rw-r--r--  synapse/rest/admin/groups.py | 2
-rw-r--r--  synapse/rest/admin/media.py | 60
-rw-r--r--  synapse/rest/admin/registration_tokens.py | 3
-rw-r--r--  synapse/rest/admin/rooms.py | 82
-rw-r--r--  synapse/rest/admin/server_notice_servlet.py | 4
-rw-r--r--  synapse/rest/admin/statistics.py | 22
-rw-r--r--  synapse/rest/admin/username_available.py | 2
-rw-r--r--  synapse/rest/admin/users.py | 81
-rw-r--r--  synapse/rest/client/devices.py | 6
-rw-r--r--  synapse/rest/client/notifications.py | 23
-rw-r--r--  synapse/rest/client/read_marker.py | 6
-rw-r--r--  synapse/rest/client/receipts.py | 4
-rw-r--r--  synapse/rest/client/relations.py | 11
-rw-r--r--  synapse/rest/client/room.py | 12
-rw-r--r--  synapse/rest/client/sync.py | 11
-rw-r--r--  synapse/rest/client/versions.py | 4
-rw-r--r--  synapse/rest/key/v2/local_key_resource.py | 4
-rw-r--r--  synapse/rest/media/v1/media_repository.py | 19
-rw-r--r--  synapse/rest/media/v1/oembed.py | 5
-rw-r--r--  synapse/rest/media/v1/preview_html.py | 397
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py | 383
-rw-r--r--  synapse/state/__init__.py | 5
-rw-r--r--  synapse/storage/_base.py | 13
-rw-r--r--  synapse/storage/database.py | 137
-rw-r--r--  synapse/storage/databases/main/__init__.py | 17
-rw-r--r--  synapse/storage/databases/main/account_data.py | 93
-rw-r--r--  synapse/storage/databases/main/appservice.py | 10
-rw-r--r--  synapse/storage/databases/main/cache.py | 9
-rw-r--r--  synapse/storage/databases/main/censor_events.py | 13
-rw-r--r--  synapse/storage/databases/main/client_ips.py | 22
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py | 11
-rw-r--r--  synapse/storage/databases/main/devices.py | 56
-rw-r--r--  synapse/storage/databases/main/directory.py | 10
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py | 237
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 211
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 32
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 273
-rw-r--r--  synapse/storage/databases/main/events.py | 148
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 77
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 4
-rw-r--r--  synapse/storage/databases/main/filtering.py | 4
-rw-r--r--  synapse/storage/databases/main/group_server.py | 10
-rw-r--r--  synapse/storage/databases/main/lock.py | 14
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 3
-rw-r--r--  synapse/storage/databases/main/metrics.py | 27
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py | 27
-rw-r--r--  synapse/storage/databases/main/presence.py | 11
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 9
-rw-r--r--  synapse/storage/databases/main/pusher.py | 29
-rw-r--r--  synapse/storage/databases/main/receipts.py | 112
-rw-r--r--  synapse/storage/databases/main/registration.py | 24
-rw-r--r--  synapse/storage/databases/main/relations.py | 42
-rw-r--r--  synapse/storage/databases/main/room.py | 226
-rw-r--r--  synapse/storage/databases/main/roommember.py | 23
-rw-r--r--  synapse/storage/databases/main/search.py | 36
-rw-r--r--  synapse/storage/databases/main/state.py | 47
-rw-r--r--  synapse/storage/databases/main/state_deltas.py | 4
-rw-r--r--  synapse/storage/databases/main/stats.py | 103
-rw-r--r--  synapse/storage/databases/main/stream.py | 54
-rw-r--r--  synapse/storage/databases/main/tags.py | 22
-rw-r--r--  synapse/storage/databases/main/transactions.py | 62
-rw-r--r--  synapse/storage/databases/main/ui_auth.py | 15
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 11
-rw-r--r--  synapse/storage/schema/__init__.py | 5
-rw-r--r--  synapse/storage/util/id_generators.py | 6
-rw-r--r--  synapse/types.py | 28
-rw-r--r--  synapse/util/__init__.py | 59
-rw-r--r--  synapse/util/async_helpers.py | 79
-rw-r--r--  synapse/util/caches/cached_call.py | 1
-rw-r--r--  synapse/util/caches/lrucache.py | 1
-rw-r--r--  synapse/util/caches/response_cache.py | 127
-rw-r--r--  synapse/util/file_consumer.py | 1
-rw-r--r--  tests/api/test_auth.py | 64
-rw-r--r--  tests/appservice/test_appservice.py | 11
-rw-r--r--  tests/federation/test_federation_sender.py | 5
-rw-r--r--  tests/federation/transport/test_knocking.py | 9
-rw-r--r--  tests/handlers/test_e2e_keys.py | 30
-rw-r--r--  tests/handlers/test_federation.py | 4
-rw-r--r--  tests/handlers/test_message.py | 103
-rw-r--r--  tests/replication/slave/storage/test_events.py | 7
-rw-r--r--  tests/replication/test_federation_ack.py | 6
-rw-r--r--  tests/rest/admin/test_background_updates.py | 3
-rw-r--r--  tests/rest/admin/test_federation.py | 37
-rw-r--r--  tests/rest/admin/test_media.py | 6
-rw-r--r--  tests/rest/admin/test_registration_tokens.py | 25
-rw-r--r--  tests/rest/admin/test_room.py | 62
-rw-r--r--  tests/rest/admin/test_user.py | 105
-rw-r--r--  tests/rest/client/test_auth.py | 134
-rw-r--r--  tests/rest/client/test_relations.py | 125
-rw-r--r--  tests/rest/client/test_room_batch.py | 180
-rw-r--r--  tests/rest/media/v1/test_url_preview.py | 1
-rw-r--r--  tests/server.py | 199
-rw-r--r--  tests/storage/test_account_data.py | 4
-rw-r--r--  tests/storage/test_background_update.py | 17
-rw-r--r--  tests/storage/test_base.py | 3
-rw-r--r--  tests/storage/test_e2e_room_keys.py | 4
-rw-r--r--  tests/storage/test_event_federation.py | 2
-rw-r--r--  tests/storage/test_event_push_actions.py | 12
-rw-r--r--  tests/storage/test_roommember.py | 2
-rw-r--r--  tests/test_preview.py | 46
-rw-r--r--  tests/unittest.py | 11
-rw-r--r--  tests/util/caches/test_response_cache.py | 45
-rw-r--r--  tests/util/test_glob_to_regex.py | 59
-rw-r--r--  tests/util/test_logcontext.py | 35
-rw-r--r--  tests/utils.py | 175
-rw-r--r--  tox.ini | 2
205 files changed, 4965 insertions, 2756 deletions
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 21c9ee7823..cb72e1a233 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -76,7 +76,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]
+        python-version: ["3.7", "3.8", "3.9", "3.10"]
         database: ["sqlite"]
         toxenv: ["py"]
         include:
@@ -85,9 +85,9 @@ jobs:
             toxenv: "py-noextras"
 
           # Oldest Python with PostgreSQL
-          - python-version: "3.6"
+          - python-version: "3.7"
             database: "postgres"
-            postgres-version: "9.6"
+            postgres-version: "10"
             toxenv: "py"
 
           # Newest Python with newest PostgreSQL
@@ -167,7 +167,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["pypy-3.6"]
+        python-version: ["pypy-3.7"]
 
     steps:
       - uses: actions/checkout@v2
@@ -291,8 +291,8 @@ jobs:
     strategy:
       matrix:
         include:
-          - python-version: "3.6"
-            postgres-version: "9.6"
+          - python-version: "3.7"
+            postgres-version: "10"
 
           - python-version: "3.10"
             postgres-version: "14"
diff --git a/CHANGES.md b/CHANGES.md
index 72e8d64cf7..e8cd60e9e5 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,6 +1,134 @@
-Synapse 1.49.0rc1 (2021-12-07)
+Synapse 1.50.0rc1 (2022-01-05)
 ==============================
 
+Please note that we now only support Python 3.7+ and PostgreSQL 10+ (if applicable), because Python 3.6 and PostgreSQL 9.6 have reached end-of-life.
+
+
+Features
+--------
+
+- Allow guests to send state events per [MSC3419](https://github.com/matrix-org/matrix-doc/pull/3419). ([\#11378](https://github.com/matrix-org/synapse/issues/11378))
+- Add experimental support for part of [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): allowing application services to masquerade as specific devices. ([\#11538](https://github.com/matrix-org/synapse/issues/11538))
+- Add admin API to get users' account data. ([\#11664](https://github.com/matrix-org/synapse/issues/11664))
+- Include the room topic in the stripped state included with invites and knocking. ([\#11666](https://github.com/matrix-org/synapse/issues/11666))
+- Send and handle cross-signing messages using the stable prefix. ([\#10520](https://github.com/matrix-org/synapse/issues/10520))
+- Support unprefixed versions of fallback key property names. ([\#11541](https://github.com/matrix-org/synapse/issues/11541))
+
+
+Bugfixes
+--------
+
+- Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event. ([\#11516](https://github.com/matrix-org/synapse/issues/11516))
+- Fix a long-standing bug which could cause `AssertionError`s to be written to the log when Synapse was restarted after purging events from the database. ([\#11536](https://github.com/matrix-org/synapse/issues/11536), [\#11642](https://github.com/matrix-org/synapse/issues/11642))
+- Fix a bug introduced in Synapse 1.17.0 where a pusher created for an email with capital letters would fail to be created. ([\#11547](https://github.com/matrix-org/synapse/issues/11547))
+- Fix a long-standing bug where responses included bundled aggregations when they should not, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11592](https://github.com/matrix-org/synapse/issues/11592), [\#11623](https://github.com/matrix-org/synapse/issues/11623))
+- Fix a long-standing bug that some unknown endpoints would return HTML error pages instead of JSON `M_UNRECOGNIZED` errors. ([\#11602](https://github.com/matrix-org/synapse/issues/11602))
+- Fix a bug introduced in Synapse 1.19.3 which could sometimes cause `AssertionError`s when backfilling rooms over federation. ([\#11632](https://github.com/matrix-org/synapse/issues/11632))
+- Fix a bug in `SimpleHttpClient.get_json` that results in the `Accept` request header being absent. ([\#11677](https://github.com/matrix-org/synapse/issues/11677))
+
+
+Improved Documentation
+----------------------
+
+- Update Synapse install command for FreeBSD as the package is now prefixed with `py38`. Contributed by @itchychips. ([\#11267](https://github.com/matrix-org/synapse/issues/11267))
+- Document the usage of refresh tokens. ([\#11427](https://github.com/matrix-org/synapse/issues/11427))
+- Add details for how to configure a TURN server when behind a NAT. Contributed by @AndrewFerr. ([\#11553](https://github.com/matrix-org/synapse/issues/11553))
+- Add references for using Postgres to the Docker documentation. ([\#11640](https://github.com/matrix-org/synapse/issues/11640))
+- Fix the documentation link in newly-generated configuration files. ([\#11678](https://github.com/matrix-org/synapse/issues/11678))
+- Correct the documentation for `nginx` to use a case-sensitive url pattern. Fixes an error introduced in v1.21.0. ([\#11680](https://github.com/matrix-org/synapse/issues/11680))
+- Clarify SSO mapping provider documentation by writing `def` or `async def` before the names of methods, as appropriate. ([\#11681](https://github.com/matrix-org/synapse/issues/11681))
+
+
+Deprecations and Removals
+-------------------------
+
+- Replace `mock` package by its standard library version. ([\#11588](https://github.com/matrix-org/synapse/issues/11588))
+
+
+Internal Changes
+----------------
+
+- Allow specific, experimental events to be created without `prev_events`. Used by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#11243](https://github.com/matrix-org/synapse/issues/11243))
+- A test helper (`wait_for_background_updates`) no longer depends on classes defining a `store` property. ([\#11331](https://github.com/matrix-org/synapse/issues/11331))
+- Add type hints to `synapse.appservice`. ([\#11360](https://github.com/matrix-org/synapse/issues/11360))
+- Add missing type hints to `synapse.config` module. ([\#11480](https://github.com/matrix-org/synapse/issues/11480))
+- Add test to ensure we share the same `state_group` across the whole historical batch when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint. ([\#11487](https://github.com/matrix-org/synapse/issues/11487))
+- Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`. ([\#11503](https://github.com/matrix-org/synapse/issues/11503))
+- Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common`. ([\#11505](https://github.com/matrix-org/synapse/issues/11505), [\#11687](https://github.com/matrix-org/synapse/issues/11687))
+- Use `HTTPStatus` constants in place of literals in `tests.rest.client.test_auth`. ([\#11520](https://github.com/matrix-org/synapse/issues/11520))
+- Add a receipt types constant for `m.read`. ([\#11531](https://github.com/matrix-org/synapse/issues/11531))
+- Clean up `synapse.rest.admin`. ([\#11535](https://github.com/matrix-org/synapse/issues/11535))
+- Add missing `errcode` to `parse_string` and `parse_boolean`. ([\#11542](https://github.com/matrix-org/synapse/issues/11542))
+- Use `HTTPStatus` constants in place of literals in `synapse.http`. ([\#11543](https://github.com/matrix-org/synapse/issues/11543))
+- Add missing type hints to storage classes. ([\#11546](https://github.com/matrix-org/synapse/issues/11546), [\#11549](https://github.com/matrix-org/synapse/issues/11549), [\#11551](https://github.com/matrix-org/synapse/issues/11551), [\#11555](https://github.com/matrix-org/synapse/issues/11555), [\#11575](https://github.com/matrix-org/synapse/issues/11575), [\#11589](https://github.com/matrix-org/synapse/issues/11589), [\#11594](https://github.com/matrix-org/synapse/issues/11594), [\#11652](https://github.com/matrix-org/synapse/issues/11652), [\#11653](https://github.com/matrix-org/synapse/issues/11653), [\#11654](https://github.com/matrix-org/synapse/issues/11654), [\#11657](https://github.com/matrix-org/synapse/issues/11657))
+- Fix an inaccurate and misleading comment in the `/sync` code. ([\#11550](https://github.com/matrix-org/synapse/issues/11550))
+- Add missing type hints to `synapse.logging.context`. ([\#11556](https://github.com/matrix-org/synapse/issues/11556))
+- Stop populating unused database column `state_events.prev_state`. ([\#11558](https://github.com/matrix-org/synapse/issues/11558))
+- Minor efficiency improvements in event persistence. ([\#11560](https://github.com/matrix-org/synapse/issues/11560))
+- Add some safety checks that storage functions are used correctly. ([\#11564](https://github.com/matrix-org/synapse/issues/11564), [\#11580](https://github.com/matrix-org/synapse/issues/11580))
+- Make `get_device` return `None` if the device doesn't exist rather than raising an exception. ([\#11565](https://github.com/matrix-org/synapse/issues/11565))
+- Split the HTML parsing code from the URL preview resource code. ([\#11566](https://github.com/matrix-org/synapse/issues/11566))
+- Remove redundant `COALESCE()`s around `COUNT()`s in database queries. ([\#11570](https://github.com/matrix-org/synapse/issues/11570))
+- Add missing type hints to `synapse.http`. ([\#11571](https://github.com/matrix-org/synapse/issues/11571))
+- Add [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) and [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) to `/versions` -> `unstable_features` to detect server support. ([\#11582](https://github.com/matrix-org/synapse/issues/11582))
+- Add type hints to `synapse/tests/rest/admin`. ([\#11590](https://github.com/matrix-org/synapse/issues/11590))
+- Drop end-of-life Python 3.6 and Postgres 9.6 from CI. ([\#11595](https://github.com/matrix-org/synapse/issues/11595))
+- Update black version and run it on all the files. ([\#11596](https://github.com/matrix-org/synapse/issues/11596))
+- Add opentracing type stubs and fix associated mypy errors. ([\#11603](https://github.com/matrix-org/synapse/issues/11603), [\#11622](https://github.com/matrix-org/synapse/issues/11622))
+- Improve OpenTracing support for requests which use a `ResponseCache`. ([\#11607](https://github.com/matrix-org/synapse/issues/11607))
+- Improve OpenTracing support for incoming HTTP requests. ([\#11618](https://github.com/matrix-org/synapse/issues/11618))
+- A number of improvements to opentracing support. ([\#11619](https://github.com/matrix-org/synapse/issues/11619))
+- Drop support for Python 3.6 and Ubuntu 18.04. ([\#11633](https://github.com/matrix-org/synapse/issues/11633))
+- Refactor the way that the `outlier` flag is set on events received over federation. ([\#11634](https://github.com/matrix-org/synapse/issues/11634))
+- Improve the error messages from `get_create_event_for_room`. ([\#11638](https://github.com/matrix-org/synapse/issues/11638))
+- Remove redundant `get_current_events_token` method. ([\#11643](https://github.com/matrix-org/synapse/issues/11643))
+- Convert `namedtuples` to `attrs`. ([\#11665](https://github.com/matrix-org/synapse/issues/11665), [\#11574](https://github.com/matrix-org/synapse/issues/11574))
+
+
+Synapse 1.49.2 (2021-12-21)
+===========================
+
+This release fixes a regression introduced in Synapse 1.49.0 which could cause `/sync` requests to take significantly longer. This would particularly affect "initial" syncs for users participating in a large number of rooms, and in extreme cases, could make it impossible for such users to log in on a new client.
+
+**Note:** in line with our [deprecation policy](https://matrix-org.github.io/synapse/latest/deprecation_policy.html) for platform dependencies, this will be the last release to support Python 3.6 and PostgreSQL 9.6, both of which have now reached upstream end-of-life. Synapse will require Python 3.7+ and PostgreSQL 10+.
+
+**Note:** We will also stop producing packages for Ubuntu 18.04 (Bionic Beaver) after this release, as it uses Python 3.6.
+
+Bugfixes
+--------
+
+- Fix a performance regression in `/sync` handling, introduced in 1.49.0. ([\#11583](https://github.com/matrix-org/synapse/issues/11583))
+
+Internal Changes
+----------------
+
+- Work around a build problem on Debian Buster. ([\#11625](https://github.com/matrix-org/synapse/issues/11625))
+
+
+Synapse 1.49.1 (2021-12-21)
+===========================
+
+Not released due to problems building the debian packages.
+
+
+Synapse 1.49.0 (2021-12-14)
+===========================
+
+No significant changes since version 1.49.0rc1.
+
+
+Support for Ubuntu 21.04 ends next month on the 20th of January
+---------------------------------------------------------------
+
+For users of Ubuntu 21.04 (Hirsute Hippo), please be aware that [upstream support for this version of Ubuntu will end next month][Ubuntu2104EOL].
+We will stop producing packages for Ubuntu 21.04 after upstream support ends.
+
+[Ubuntu2104EOL]: https://lists.ubuntu.com/archives/ubuntu-announce/2021-December/000275.html
+
+
+The wiki has been migrated to the documentation website
+-------------------------------------------------------
+
 We've decided to move the existing, somewhat stagnant pages from the GitHub wiki
 to the [documentation website](https://matrix-org.github.io/synapse/latest/).
 
@@ -16,6 +144,9 @@ requests](https://github.com/matrix-org/synapse/pulls). Please visit [#synapse-d
 if you need help with the process!
 
 
+Synapse 1.49.0rc1 (2021-12-07)
+==============================
+
 Features
 --------
 
diff --git a/contrib/docker/docker-compose.yml b/contrib/docker/docker-compose.yml
index 26d640c448..5ac41139e3 100644
--- a/contrib/docker/docker-compose.yml
+++ b/contrib/docker/docker-compose.yml
@@ -14,6 +14,7 @@ services:
     # failure
     restart: unless-stopped
     # See the readme for a full documentation of the environment settings
+    # NOTE: You must edit homeserver.yaml to use postgres; it defaults to sqlite
     environment:
       - SYNAPSE_CONFIG_PATH=/data/homeserver.yaml
     volumes:
diff --git a/debian/changelog b/debian/changelog
index acc9f6049e..b54c0ff348 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,27 @@
+matrix-synapse-py3 (1.50.0~rc1) stable; urgency=medium
+
+  * New synapse release 1.50.0~rc1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Wed, 05 Jan 2022 12:36:17 +0000
+
+matrix-synapse-py3 (1.49.2) stable; urgency=medium
+
+  * New synapse release 1.49.2.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 21 Dec 2021 17:31:03 +0000
+
+matrix-synapse-py3 (1.49.1) stable; urgency=medium
+
+  * New synapse release 1.49.1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 21 Dec 2021 11:07:30 +0000
+
+matrix-synapse-py3 (1.49.0) stable; urgency=medium
+
+  * New synapse release 1.49.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 14 Dec 2021 12:39:46 +0000
+
 matrix-synapse-py3 (1.49.0~rc1) stable; urgency=medium
 
   * New synapse release 1.49.0~rc1.
diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv
index 1dd88140c7..fbc1d2346f 100644
--- a/docker/Dockerfile-dhvirtualenv
+++ b/docker/Dockerfile-dhvirtualenv
@@ -16,7 +16,7 @@ ARG distro=""
 ### Stage 0: build a dh-virtualenv
 ###
 
-# This is only really needed on bionic and focal, since other distributions we
+# This is only really needed on focal, since other distributions we
 # care about have a recent version of dh-virtualenv by default. Unfortunately,
 # it looks like focal is going to be with us for a while.
 #
@@ -36,9 +36,8 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
         wget
 
 # fetch and unpack the package
-# TODO: Upgrade to 1.2.2 once bionic is dropped (1.2.2 requires debhelper 12; bionic has only 11)
 RUN mkdir /dh-virtualenv
-RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz
+RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/refs/tags/1.2.2.tar.gz
 RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz
 
 # install its build deps. We do another apt-cache-update here, because we might
@@ -86,12 +85,12 @@ RUN apt-get update -qq -o Acquire::Languages=none \
         libpq-dev \
         xmlsec1
 
-COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb /
+COPY --from=builder /dh-virtualenv_1.2.2-1_all.deb /
 
 # install dhvirtualenv. Update the apt cache again first, in case we got a
 # cached cache from docker the first time.
 RUN apt-get update -qq -o Acquire::Languages=none \
-    && apt-get install -yq /dh-virtualenv_1.2~dev-1_all.deb
+    && apt-get install -yq /dh-virtualenv_1.2.2-1_all.deb
 
 WORKDIR /synapse/source
 ENTRYPOINT ["bash","/synapse/source/docker/build_debian.sh"]
diff --git a/docker/README.md b/docker/README.md
index 4349e71f87..67c3bc65f0 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -68,6 +68,10 @@ The following environment variables are supported in `generate` mode:
   directories. If unset, and no user is set via `docker run --user`, defaults
   to `991`, `991`.
 
+## Postgres
+
+By default the config will use SQLite. See the [docs on using Postgres](https://github.com/matrix-org/synapse/blob/develop/docs/postgres.md) for more information. Until this section is improved, [this issue](https://github.com/matrix-org/synapse/issues/8304) may provide useful information.
+
 ## Running synapse
 
 Once you have a valid configuration file, you can start synapse as follows:
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index b05af6d690..11f597b3ed 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -30,6 +30,7 @@
         - [SSO Mapping Providers](sso_mapping_providers.md)
       - [Password Auth Providers](password_auth_providers.md)
       - [JSON Web Tokens](jwt.md)
+      - [Refresh Tokens](usage/configuration/user_authentication/refresh_tokens.md)
     - [Registration Captcha](CAPTCHA_SETUP.md)
     - [Application Services](application_services.md)
     - [Server Notices](server_notices.md)
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index ba574d795f..74933d2fcf 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -480,6 +480,81 @@ The following fields are returned in the JSON response body:
 - `joined_rooms` - An array of `room_id`.
 - `total` - Number of rooms.
 
+## Account Data
+Gets information about account data for a specific `user_id`.
+
+The API is:
+
+```
+GET /_synapse/admin/v1/users/<user_id>/accountdata
+```
+
+A response body like the following is returned:
+
+```json
+{
+    "account_data": {
+        "global": {
+            "m.secret_storage.key.LmIGHTg5W": {
+                "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
+                "iv": "fwjNZatxg==",
+                "mac": "eWh9kNnLWZUNOgnc="
+            },
+            "im.vector.hide_profile": {
+                "hide_profile": true
+            },
+            "org.matrix.preview_urls": {
+                "disable": false
+            },
+            "im.vector.riot.breadcrumb_rooms": {
+                "rooms": [
+                    "!LxcBDAsDUVAfJDEo:matrix.org",
+                    "!MAhRxqasbItjOqxu:matrix.org"
+                ]
+            },
+            "m.accepted_terms": {
+                "accepted": [
+                    "https://example.org/somewhere/privacy-1.2-en.html",
+                    "https://example.org/somewhere/terms-2.0-en.html"
+                ]
+            },
+            "im.vector.setting.breadcrumbs": {
+                "recent_rooms": [
+                    "!MAhRxqasbItqxuEt:matrix.org",
+                    "!ZtSaPCawyWtxiImy:matrix.org"
+                ]
+            }
+        },
+        "rooms": {
+            "!GUdfZSHUJibpiVqHYd:matrix.org": {
+                "m.fully_read": {
+                    "event_id": "$156334540fYIhZ:matrix.org"
+                }
+            },
+            "!tOZwOOiqwCYQkLhV:matrix.org": {
+                "m.fully_read": {
+                    "event_id": "$xjsIyp4_NaVl2yPvIZs_k1Jl8tsC_Sp23wjqXPno"
+                }
+            }
+        }
+    }
+}
+```
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- `user_id` - fully qualified: for example, `@user:server.com`.
+
+**Response**
+
+The following fields are returned in the JSON response body:
+
+- `account_data` - A map containing the account data for the user
+  - `global` - A map containing the global account data for the user
+  - `rooms` - A map containing the account data per room for the user
+
 ## User media
 
 ### List media uploaded by a user
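
> Editor's note: as an illustration only (not part of this commit), here is a minimal sketch of calling the new account-data admin endpoint from Python with the `requests` library. The homeserver URL, the `ADMIN_TOKEN` environment variable, and the user ID are assumptions for the example.

```python
import os

import requests

# Hypothetical values; substitute your own homeserver and admin access token.
BASE_URL = "https://matrix.example.com"
ADMIN_TOKEN = os.environ["ADMIN_TOKEN"]
USER_ID = "@user:example.com"

# Fetch the user's global and per-room account data via the new admin endpoint.
resp = requests.get(
    f"{BASE_URL}/_synapse/admin/v1/users/{USER_ID}/accountdata",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()
account_data = resp.json()["account_data"]

print(sorted(account_data["global"]))  # global account-data event types
print(sorted(account_data["rooms"]))   # room IDs that have per-room account data
```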
diff --git a/docs/deprecation_policy.md b/docs/deprecation_policy.md
index 06ea340559..359dac07c3 100644
--- a/docs/deprecation_policy.md
+++ b/docs/deprecation_policy.md
@@ -14,8 +14,8 @@ i.e. when a version reaches End of Life Synapse will withdraw support for that
 version in future releases.
 
 Details on the upstream support life cycles for Python and PostgreSQL are
-documented at https://endoflife.date/python and
-https://endoflife.date/postgresql.
+documented at [https://endoflife.date/python](https://endoflife.date/python) and
+[https://endoflife.date/postgresql](https://endoflife.date/postgresql).
 
 
 Context
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index f3b3aea732..1a89da50fd 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -63,7 +63,7 @@ server {
 
     server_name matrix.example.com;
 
-    location ~* ^(\/_matrix|\/_synapse\/client) {
+    location ~ ^(/_matrix|/_synapse/client) {
         # note: do not add a path (even a single /) after the port in `proxy_pass`,
         # otherwise nginx will canonicalise the URI and cause signature verification
         # errors.
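
> Editor's note: the change above swaps nginx's case-insensitive regex match (`~*`) for a case-sensitive one (`~`). A small sketch, using Python's `re` module purely to illustrate the matching behaviour the new pattern implies (the example paths are assumptions):

```python
import re

# Equivalent of the case-sensitive `location ~` pattern in the nginx config above.
MATRIX_LOCATION = re.compile(r"^(/_matrix|/_synapse/client)")

print(bool(MATRIX_LOCATION.match("/_matrix/client/v3/sync")))   # True: proxied
print(bool(MATRIX_LOCATION.match("/_Matrix/client/v3/sync")))   # False: no longer matched
```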
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 6696ed5d1e..810a14b077 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -37,7 +37,7 @@
 
 # Server admins can expand Synapse's functionality with external modules.
 #
-# See https://matrix-org.github.io/synapse/latest/modules.html for more
+# See https://matrix-org.github.io/synapse/latest/modules/index.html for more
 # documentation on how to configure or create custom modules for Synapse.
 #
 modules:
@@ -1488,6 +1488,7 @@ room_prejoin_state:
    # - m.room.encryption
    # - m.room.name
    # - m.room.create
+   # - m.room.topic
    #
    # Uncomment the following to disable these defaults (so that only the event
    # types listed in 'additional_event_types' are shared). Defaults to 'false'.
diff --git a/docs/setup/installation.md b/docs/setup/installation.md
index 16562be953..210c80dace 100644
--- a/docs/setup/installation.md
+++ b/docs/setup/installation.md
@@ -164,7 +164,7 @@ xbps-install -S synapse
 Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
 
 - Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
-- Packages: `pkg install py37-matrix-synapse`
+- Packages: `pkg install py38-matrix-synapse`
 
 #### OpenBSD
 
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index 7a407012e0..7b4ddc5b74 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -49,12 +49,12 @@ comment these options out and use those specified by the module instead.
 
 A custom mapping provider must specify the following methods:
 
-* `__init__(self, parsed_config)`
+* `def __init__(self, parsed_config)`
    - Arguments:
      - `parsed_config` - A configuration object that is the return value of the
        `parse_config` method. You should set any configuration options needed by
        the module here.
-* `parse_config(config)`
+* `def parse_config(config)`
     - This method should have the `@staticmethod` decoration.
     - Arguments:
         - `config` - A `dict` representing the parsed content of the
@@ -63,13 +63,13 @@ A custom mapping provider must specify the following methods:
            any option values they need here.
     - Whatever is returned will be passed back to the user mapping provider module's
       `__init__` method during construction.
-* `get_remote_user_id(self, userinfo)`
+* `def get_remote_user_id(self, userinfo)`
     - Arguments:
       - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
                      information from.
     - This method must return a string, which is the unique, immutable identifier
       for the user. Commonly the `sub` claim of the response.
-* `map_user_attributes(self, userinfo, token, failures)`
+* `async def map_user_attributes(self, userinfo, token, failures)`
     - This method must be async.
     - Arguments:
       - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
@@ -91,7 +91,7 @@ A custom mapping provider must specify the following methods:
         during a user's first login. Once a localpart has been associated with a
         remote user ID (see `get_remote_user_id`) it cannot be updated.
       - `displayname`: An optional string, the display name for the user.
-* `get_extra_attributes(self, userinfo, token)`
+* `async def get_extra_attributes(self, userinfo, token)`
     - This method must be async.
     - Arguments:
       - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
@@ -125,15 +125,15 @@ comment these options out and use those specified by the module instead.
 
 A custom mapping provider must specify the following methods:
 
-* `__init__(self, parsed_config, module_api)`
+* `def __init__(self, parsed_config, module_api)`
    - Arguments:
      - `parsed_config` - A configuration object that is the return value of the
        `parse_config` method. You should set any configuration options needed by
        the module here.
      - `module_api` - a `synapse.module_api.ModuleApi` object which provides the
        stable API available for extension modules.
-* `parse_config(config)`
-    - This method should have the `@staticmethod` decoration.
+* `def parse_config(config)`
+    - **This method should have the `@staticmethod` decoration.**
     - Arguments:
         - `config` - A `dict` representing the parsed content of the
           `saml_config.user_mapping_provider.config` homeserver config option.
@@ -141,15 +141,15 @@ A custom mapping provider must specify the following methods:
            any option values they need here.
     - Whatever is returned will be passed back to the user mapping provider module's
       `__init__` method during construction.
-* `get_saml_attributes(config)`
-    - This method should have the `@staticmethod` decoration.
+* `def get_saml_attributes(config)`
+    - **This method should have the `@staticmethod` decoration.**
     - Arguments:
         - `config` - A object resulting from a call to `parse_config`.
     - Returns a tuple of two sets. The first set equates to the SAML auth
       response attributes that are required for the module to function, whereas
       the second set consists of those attributes which can be used if available,
       but are not necessary.
-* `get_remote_user_id(self, saml_response, client_redirect_url)`
+* `def get_remote_user_id(self, saml_response, client_redirect_url)`
     - Arguments:
       - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
                           information from.
@@ -157,7 +157,7 @@ A custom mapping provider must specify the following methods:
                                 redirected to.
     - This method must return a string, which is the unique, immutable identifier
       for the user. Commonly the `uid` claim of the response.
-* `saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
+* `def saml_response_to_user_attributes(self, saml_response, failures, client_redirect_url)`
     - Arguments:
       - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
                           information from.
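
> Editor's note: to make the OIDC method list above more concrete, here is a rough, hypothetical skeleton following the signatures described in this document. The claim names (`sub`, `preferred_username`, `name`) and the class name are illustrative assumptions, not part of this change.

```python
class ExampleOidcMappingProvider:
    """Hypothetical OIDC mapping provider following the interface described above."""

    def __init__(self, parsed_config):
        self._config = parsed_config

    @staticmethod
    def parse_config(config: dict):
        # Read any module-specific options from the homeserver config here.
        return config

    def get_remote_user_id(self, userinfo) -> str:
        # The unique, immutable identifier for the user; commonly the `sub` claim.
        return userinfo["sub"]

    async def map_user_attributes(self, userinfo, token, failures: int) -> dict:
        # Append the failure count to avoid localpart collisions on retries.
        suffix = str(failures) if failures else ""
        return {
            "localpart": userinfo["preferred_username"] + suffix,
            "displayname": userinfo.get("name"),
        }

    async def get_extra_attributes(self, userinfo, token) -> dict:
        # Extra attributes to return to the client in the login response.
        return {}
```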
diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index e6812de69e..e32aaa1850 100644
--- a/docs/turn-howto.md
+++ b/docs/turn-howto.md
@@ -15,8 +15,8 @@ The following sections describe how to install [coturn](<https://github.com/cotu
 
 For TURN relaying with `coturn` to work, it must be hosted on a server/endpoint with a public IP.
 
-Hosting TURN behind a NAT (even with appropriate port forwarding) is known to cause issues
-and to often not work.
+Hosting TURN behind NAT requires port forwarding and a NAT gateway with a public IP.
+However, even with appropriate configuration, NAT setups are known to cause issues and often do not work.
 
 ## `coturn` setup
 
@@ -103,7 +103,23 @@ This will install and start a systemd service called `coturn`.
     denied-peer-ip=192.168.0.0-192.168.255.255
     denied-peer-ip=172.16.0.0-172.31.255.255
 
+    # recommended additional local peers to block, to mitigate external access to internal services.
+    # https://www.rtcsec.com/article/slack-webrtc-turn-compromise-and-bug-bounty/#how-to-fix-an-open-turn-relay-to-address-this-vulnerability
+    no-multicast-peers
+    denied-peer-ip=0.0.0.0-0.255.255.255
+    denied-peer-ip=100.64.0.0-100.127.255.255
+    denied-peer-ip=127.0.0.0-127.255.255.255
+    denied-peer-ip=169.254.0.0-169.254.255.255
+    denied-peer-ip=192.0.0.0-192.0.0.255
+    denied-peer-ip=192.0.2.0-192.0.2.255
+    denied-peer-ip=192.88.99.0-192.88.99.255
+    denied-peer-ip=198.18.0.0-198.19.255.255
+    denied-peer-ip=198.51.100.0-198.51.100.255
+    denied-peer-ip=203.0.113.0-203.0.113.255
+    denied-peer-ip=240.0.0.0-255.255.255.255
+
     # special case the turn server itself so that client->TURN->TURN->client flows work
+    # this should be one of the turn server's listening IPs
     allowed-peer-ip=10.0.0.1
 
     # consider whether you want to limit the quota of relayed streams per user (or total) to avoid risk of DoS.
@@ -123,7 +139,7 @@ This will install and start a systemd service called `coturn`.
     pkey=/path/to/privkey.pem
     ```
 
-    In this case, replace the `turn:` schemes in the `turn_uri` settings below
+    In this case, replace the `turn:` schemes in the `turn_uris` settings below
     with `turns:`.
 
     We recommend that you only try to set up TLS/DTLS once you have set up a
@@ -134,21 +150,33 @@ This will install and start a systemd service called `coturn`.
     traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535
     for the UDP relay.)
 
-1.  We do not recommend running a TURN server behind NAT, and are not aware of
-    anyone doing so successfully.
+1.  If your TURN server is behind NAT, the NAT gateway must have an external,
+    publicly-reachable IP address. You must configure coturn to advertise that
+    address to connecting clients:
+
+    ```
+    external-ip=EXTERNAL_NAT_IPv4_ADDRESS
+    ```
 
-    If you want to try it anyway, you will at least need to tell coturn its
-    external IP address:
+    You may optionally limit the TURN server to listen only on the local
+    address that is mapped by NAT to the external address:
 
     ```
-    external-ip=192.88.99.1
+    listening-ip=INTERNAL_TURNSERVER_IPv4_ADDRESS
     ```
 
-    ... and your NAT gateway must forward all of the relayed ports directly
-    (eg, port 56789 on the external IP must be always be forwarded to port
-    56789 on the internal IP).
+    If your NAT gateway is reachable over both IPv4 and IPv6, you may
+    configure coturn to advertise each available address:
 
-    If you get this working, let us know!
+    ```
+    external-ip=EXTERNAL_NAT_IPv4_ADDRESS
+    external-ip=EXTERNAL_NAT_IPv6_ADDRESS
+    ```
+
+    When advertising an external IPv6 address, ensure that the firewall and
+    network settings of the system running your TURN server are configured to
+    accept IPv6 traffic, and that the TURN server is listening on the local
+    IPv6 address that is mapped by NAT to the external IPv6 address.
 
 1.  (Re)start the turn server:
 
@@ -216,9 +244,6 @@ connecting". Unfortunately, troubleshooting this can be tricky.
 
 Here are a few things to try:
 
- * Check that your TURN server is not behind NAT. As above, we're not aware of
-   anyone who has successfully set this up.
-
  * Check that you have opened your firewall to allow TCP and UDP traffic to the
    TURN ports (normally 3478 and 5349).
 
@@ -234,6 +259,18 @@ Here are a few things to try:
    Try removing any AAAA records for your TURN server, so that it is only
    reachable over IPv4.
 
+ * If your TURN server is behind NAT:
+
+    * double-check that your NAT gateway is correctly forwarding all TURN
+      ports (normally 3478 & 5349 for TCP & UDP TURN traffic, and 49152-65535 for the UDP
+      relay) to the NAT-internal address of your TURN server. If advertising
+      both IPv4 and IPv6 external addresses via the `external-ip` option, ensure
+      that the NAT is forwarding both IPv4 and IPv6 traffic to the IPv4 and IPv6
+      internal addresses of your TURN server. When in doubt, remove AAAA records
+      for your TURN server and specify only an IPv4 address as your `external-ip`.
+
+    * ensure that your TURN server uses the NAT gateway as its default route.
+
  * Enable more verbose logging in coturn via the `verbose` setting:
 
    ```
diff --git a/docs/upgrade.md b/docs/upgrade.md
index 136c806c41..30bb0dcd9c 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -85,6 +85,17 @@ process, for example:
     dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
     ```
 
+# Upgrading to v1.50.0
+
+## Dropping support for old Python and Postgres versions
+
+In line with our [deprecation policy](deprecation_policy.md),
+we've dropped support for Python 3.6 and PostgreSQL 9.6, as they are no
+longer supported upstream.
+
+This release of Synapse requires Python 3.7+ and PostgreSQL 10+.
+
+
 # Upgrading to v1.47.0
 
 ## Removal of old Room Admin API
diff --git a/docs/usage/configuration/user_authentication/refresh_tokens.md b/docs/usage/configuration/user_authentication/refresh_tokens.md
new file mode 100644
index 0000000000..23b3cddae0
--- /dev/null
+++ b/docs/usage/configuration/user_authentication/refresh_tokens.md
@@ -0,0 +1,139 @@
+# Refresh Tokens
+
+Synapse supports refresh tokens since version 1.49 (some earlier versions had support for an earlier, experimental draft of [MSC2918] which is not compatible).
+
+
+[MSC2918]: https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens
+
+
+## Background and motivation
+
+Synapse users' sessions are identified by **access tokens**; access tokens are
+issued to users on login. Each session gets a unique access token which identifies
+it; the access token must be kept secret as it grants access to the user's account.
+
+Traditionally, these access tokens were eternally valid (at least until the user
+explicitly chose to log out).
+
+In some cases, it may be desirable for these access tokens to expire so that the
+potential damage caused by leaking an access token is reduced.
+On the other hand, forcing a user to re-authenticate (log in again) often might
+be too much of an inconvenience.
+
+**Refresh tokens** are a mechanism to avoid some of this inconvenience whilst
+still getting most of the benefits of short access token lifetimes.
+Refresh tokens are also a concept present in OAuth 2 — further reading is available
+[here](https://datatracker.ietf.org/doc/html/rfc6749#section-1.5).
+
+When refresh tokens are in use, both an access token and a refresh token will be
+issued to users on login. The access token will expire after a predetermined amount
+of time, but otherwise works in the same way as before. When the access token is
+close to expiring (or has expired), the user's client should present the homeserver
+(Synapse) with the refresh token.
+
+The homeserver will then generate a new access token and refresh token for the user
+and return them. The old refresh token is invalidated and cannot be used again*.
+
+Finally, refresh tokens also make it possible for sessions to be logged out if they
+are inactive for too long, before the session naturally ends; see the configuration
+guide below.
+
+
+*To prevent issues if clients lose connection half-way through refreshing a token,
+the refresh token is only invalidated once the new access token has been used at
+least once. For all intents and purposes, the above simplification is sufficient.
+
+
+## Caveats
+
+There are some caveats:
+
+* If a third party gets both your access token and refresh token, they will be able to
+  continue to enjoy access to your session.
+  * This is still an improvement because you (the user) will notice when *your*
+    session expires and you're not able to use your refresh token.
+    That would be a giveaway that someone else has compromised your session.
+    You would be able to log in again and terminate that session.
+    Previously (with long-lived access tokens), a third party that has your access
+    token could go undetected for a very long time.
+* Clients need to implement support for refresh tokens in order for them to be a
+  useful mechanism.
+  * It is up to homeserver administrators if they want to issue long-lived access
+    tokens to clients not implementing refresh tokens.
+    * For compatibility, it is likely that they should, at least until client support
+      is widespread.
+      * Users with clients that support refresh tokens will still benefit from the
+        added security; it's not possible to downgrade a session to using long-lived
+        access tokens so this effectively gives users the choice.
+    * In a closed environment where all users use known clients, this may not be
+      an issue as the homeserver administrator can know if the clients have refresh
+      token support. In that case, the non-refreshable access token lifetime
+      may be set to a short duration so that a similar level of security is provided.
+
+
+## Configuration Guide
+
+The following configuration options, in the `registration` section, are related:
+
+* `session_lifetime`: maximum length of a session, even if it's refreshed.
+  In other words, the client must log in again after this time period.
+  In most cases, this can be unset (infinite) or set to a long time (years or months).
+* `refreshable_access_token_lifetime`: lifetime of access tokens that are created
+  by clients supporting refresh tokens.
+  This should be short; a good value might be 5 minutes (`5m`).
+* `nonrefreshable_access_token_lifetime`: lifetime of access tokens that are created
+  by clients which don't support refresh tokens.
+  Make this short if you want to effectively force use of refresh tokens.
+  Make this long if you don't want to inconvenience users of clients which don't
+  support refresh tokens (by forcing them to frequently re-authenticate using
+  login credentials).
+* `refresh_token_lifetime`: lifetime of refresh tokens.
+  In other words, the client must refresh within this time period to maintain its session.
+  Unless you want to log inactive sessions out, it is often fine to use a long
+  value here or even leave it unset (infinite).
+  Beware that making it too short will inconvenience clients that do not connect
+  very often, including mobile clients and clients of infrequent users (by making
+  it more difficult for them to refresh in time, which may force them to need to
+  re-authenticate using login credentials).
+
+**Note:** All four options above only apply when tokens are created (by logging in or refreshing).
+Changes to these settings do not apply retroactively.
+
+
+### Using refresh token expiry to log out inactive sessions
+
+If you'd like to force sessions to be logged out upon inactivity, you can enable
+refreshable access token expiry and refresh token expiry.
+
+This works because a client must refresh at least once within a period of
+`refresh_token_lifetime` in order to maintain valid credentials to access the
+account.
+
+(It's suggested that `refresh_token_lifetime` should be longer than
+`refreshable_access_token_lifetime` and this section assumes that to be the case
+for simplicity.)
+
+Note: this will only affect sessions using refresh tokens. You may wish to
+set a short `nonrefreshable_access_token_lifetime` to prevent this being bypassed
+by clients that do not support refresh tokens.
+
+
+#### Choosing values that guarantee permitting some inactivity
+
+It may be desirable to permit some short periods of inactivity, for example to
+accommodate brief outages in client connectivity.
+
+The following model aims to provide guidance for choosing `refresh_token_lifetime`
+and `refreshable_access_token_lifetime` to satisfy requirements of the form:
+
+1. inactivity longer than `L` **MUST** cause the session to be logged out; and
+2. inactivity shorter than `S` **MUST NOT** cause the session to be logged out.
+
+This model makes the weakest assumption that all active clients will refresh as
+needed to maintain an active access token, but no sooner.
+*In reality, clients may refresh more often than this model assumes, but the
+above requirements will still hold.*
+
+To satisfy the above model,
+* `refresh_token_lifetime` should be set to `L`; and
+* `refreshable_access_token_lifetime` should be set to `L - S`.
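
> Editor's note: a quick worked illustration of that rule (my own example, not part of the committed doc), assuming a hypothetical requirement of `L` = 7 days and `S` = 1 day:

```python
from datetime import timedelta


def lifetimes_for(max_inactivity: timedelta, tolerated_inactivity: timedelta) -> dict:
    """Derive the two config values from the L / S requirements described above."""
    # refresh_token_lifetime = L: refreshing less often than L logs the session out.
    # refreshable_access_token_lifetime = L - S: a client that refreshes as lazily
    # as possible still survives any connectivity gap shorter than S.
    return {
        "refresh_token_lifetime": max_inactivity,
        "refreshable_access_token_lifetime": max_inactivity - tolerated_inactivity,
    }


# L = 7 days, S = 1 day  ->  refresh_token_lifetime = 7 days,
#                            refreshable_access_token_lifetime = 6 days
print(lifetimes_for(timedelta(days=7), timedelta(days=1)))
```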
diff --git a/mypy.ini b/mypy.ini
index 1caf807e85..85fa22d28f 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -25,14 +25,9 @@ exclude = (?x)
   ^(
    |synapse/storage/databases/__init__.py
    |synapse/storage/databases/main/__init__.py
-   |synapse/storage/databases/main/account_data.py
    |synapse/storage/databases/main/cache.py
    |synapse/storage/databases/main/devices.py
-   |synapse/storage/databases/main/e2e_room_keys.py
-   |synapse/storage/databases/main/end_to_end_keys.py
    |synapse/storage/databases/main/event_federation.py
-   |synapse/storage/databases/main/event_push_actions.py
-   |synapse/storage/databases/main/events_bg_updates.py
    |synapse/storage/databases/main/group_server.py
    |synapse/storage/databases/main/metrics.py
    |synapse/storage/databases/main/monthly_active_users.py
@@ -40,12 +35,9 @@ exclude = (?x)
    |synapse/storage/databases/main/purge_events.py
    |synapse/storage/databases/main/push_rule.py
    |synapse/storage/databases/main/receipts.py
-   |synapse/storage/databases/main/room.py
    |synapse/storage/databases/main/roommember.py
    |synapse/storage/databases/main/search.py
    |synapse/storage/databases/main/state.py
-   |synapse/storage/databases/main/stats.py
-   |synapse/storage/databases/main/transactions.py
    |synapse/storage/databases/main/user_directory.py
    |synapse/storage/schema/
 
@@ -107,7 +99,6 @@ exclude = (?x)
    |tests/server.py
    |tests/server_notices/test_resource_limits_server_notices.py
    |tests/state/test_v2.py
-   |tests/storage/test_account_data.py
    |tests/storage/test_background_update.py
    |tests/storage/test_base.py
    |tests/storage/test_client_ips.py
@@ -145,6 +136,9 @@ disallow_untyped_defs = True
 [mypy-synapse.app.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.appservice.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.config._base]
 disallow_untyped_defs = True
 
@@ -163,6 +157,12 @@ disallow_untyped_defs = False
 [mypy-synapse.handlers.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.http.server]
+disallow_untyped_defs = True
+
+[mypy-synapse.logging.context]
+disallow_untyped_defs = True
+
 [mypy-synapse.metrics.*]
 disallow_untyped_defs = True
 
@@ -181,24 +181,48 @@ disallow_untyped_defs = True
 [mypy-synapse.state.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.account_data]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.client_ips]
 disallow_untyped_defs = True
 
 [mypy-synapse.storage.databases.main.directory]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.e2e_room_keys]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.databases.main.end_to_end_keys]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.databases.main.event_push_actions]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.databases.main.events_bg_updates]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.events_worker]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.room]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.room_batch]
 disallow_untyped_defs = True
 
 [mypy-synapse.storage.databases.main.profile]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.stats]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.state_deltas]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.transactions]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.user_erasure_store]
 disallow_untyped_defs = True
 
@@ -223,6 +247,9 @@ disallow_untyped_defs = True
 [mypy-tests.storage.test_user_directory]
 disallow_untyped_defs = True
 
+[mypy-tests.rest.admin.*]
+disallow_untyped_defs = True
+
 [mypy-tests.rest.client.test_directory]
 disallow_untyped_defs = True
 
@@ -286,9 +313,6 @@ ignore_missing_imports = True
 [mypy-netaddr]
 ignore_missing_imports = True
 
-[mypy-opentracing]
-ignore_missing_imports = True
-
 [mypy-parameterized.*]
 ignore_missing_imports = True
 
diff --git a/pyproject.toml b/pyproject.toml
index 8bca1fa4ef..963f149c6a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@
         showcontent = true
 
 [tool.black]
-target-version = ['py36']
+target-version = ['py37', 'py38', 'py39', 'py310']
 exclude = '''
 
 (
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index 3a9a2d257c..4d34e90703 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -24,7 +24,6 @@ DISTS = (
     "debian:bullseye",
     "debian:bookworm",
     "debian:sid",
-    "ubuntu:bionic",  # 18.04 LTS (our EOL forced by Py36 on 2021-12-23)
     "ubuntu:focal",  # 20.04 LTS (our EOL forced by Py38 on 2024-10-14)
     "ubuntu:hirsute",  # 21.04 (EOL 2022-01-05)
     "ubuntu:impish",  # 21.10  (EOL 2022-07)
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index af4de345df..c764011d6a 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -42,8 +42,8 @@ echo "--------------------------"
 echo
 
 matched=0
-for f in $(git diff --name-only FETCH_HEAD... -- changelog.d); do
-    # check that any modified newsfiles on this branch end with a full stop.
+for f in $(git diff --diff-filter=d --name-only FETCH_HEAD... -- changelog.d); do
+    # check that any added newsfiles on this branch end with a full stop.
     lastchar=$(tr -d '\n' < "$f" | tail -c 1)
     if [ "$lastchar" != '.' ] && [ "$lastchar" != '!' ]; then
         echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2
diff --git a/setup.py b/setup.py
index 2c6fb9aacb..e618ff898b 100755
--- a/setup.py
+++ b/setup.py
@@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
 # We pin black so that our tests don't start failing on new releases.
 CONDITIONAL_REQUIREMENTS["lint"] = [
     "isort==5.7.0",
-    "black==21.6b0",
+    "black==21.12b0",
     "flake8-comprehensions",
     "flake8-bugbear==21.3.2",
     "flake8",
@@ -107,6 +107,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
     "mypy-zope==0.3.2",
     "types-bleach>=4.1.0",
     "types-jsonschema>=3.2.0",
+    "types-opentracing>=2.4.2",
     "types-Pillow>=8.3.4",
     "types-pyOpenSSL>=20.0.7",
     "types-PyYAML>=5.4.10",
@@ -119,9 +120,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
 # Tests assume that all optional dependencies are installed.
 #
 # parameterized_class decorator was introduced in parameterized 0.7.0
-#
-# We use `mock` library as that backports `AsyncMock` to Python 3.6
-CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0", "mock>=4.0.0"]
+CONDITIONAL_REQUIREMENTS["test"] = ["parameterized>=0.7.0"]
 
 CONDITIONAL_REQUIREMENTS["dev"] = (
     CONDITIONAL_REQUIREMENTS["lint"]
@@ -163,7 +162,6 @@ setup(
         "Topic :: Communications :: Chat",
         "License :: OSI Approved :: Apache Software License",
         "Programming Language :: Python :: 3 :: Only",
-        "Programming Language :: Python :: 3.6",
         "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index 4ff3c6de5f..429234d7ae 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -17,11 +17,12 @@
 from typing import Any, List, Optional, Type, Union
 
 from twisted.internet import protocol
+from twisted.internet.defer import Deferred
 
 class RedisProtocol(protocol.Protocol):
     def publish(self, channel: str, message: bytes): ...
-    async def ping(self) -> None: ...
-    async def set(
+    def ping(self) -> "Deferred[None]": ...
+    def set(
         self,
         key: str,
         value: Any,
@@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol):
         pexpire: Optional[int] = None,
         only_if_not_exists: bool = False,
         only_if_exists: bool = False,
-    ) -> None: ...
-    async def get(self, key: str) -> Any: ...
+    ) -> "Deferred[None]": ...
+    def get(self, key: str) -> "Deferred[Any]": ...
 
 class SubscriberProtocol(RedisProtocol):
     def __init__(self, *args, **kwargs): ...
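
The stub now types txredisapi's methods as returning Twisted `Deferred`s rather than declaring them `async`, matching how the library actually behaves. A rough sketch of using such a protocol from async code (assumes an already-connected `RedisProtocol`; the key, value and expiry are made up):

    from typing import Any

    from txredisapi import RedisProtocol

    async def cache_round_trip(redis: RedisProtocol, key: str, value: Any) -> Any:
        # Both calls return Deferreds, which are awaitable from async code.
        await redis.set(key, value, pexpire=60_000)
        return await redis.get(key)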
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 6369f18a53..92aec334e6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.49.0rc1"
+__version__ = "1.50.0rc1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 44883c6663..4a32d430bd 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -32,7 +32,7 @@ from synapse.appservice import ApplicationService
 from synapse.events import EventBase
 from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
-from synapse.logging import opentracing as opentracing
+from synapse.logging.opentracing import active_span, force_tracing, start_active_span
 from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import Requester, StateMap, UserID, create_requester
 from synapse.util.caches.lrucache import LruCache
@@ -149,13 +149,53 @@ class Auth:
                 is invalid.
             AuthError if access is denied for the user in the access token
         """
+        parent_span = active_span()
+        with start_active_span("get_user_by_req"):
+            requester = await self._wrapped_get_user_by_req(
+                request, allow_guest, rights, allow_expired
+            )
+
+            if parent_span:
+                if requester.authenticated_entity in self._force_tracing_for_users:
+                    # request tracing is enabled for this user, so we need to force
+                    # tracing on for the parent span (which will be the servlet span).
+                    #
+                    # It's too late for the get_user_by_req span to inherit the setting,
+                    # so we also force it on for that.
+                    force_tracing()
+                    force_tracing(parent_span)
+                parent_span.set_tag(
+                    "authenticated_entity", requester.authenticated_entity
+                )
+                parent_span.set_tag("user_id", requester.user.to_string())
+                if requester.device_id is not None:
+                    parent_span.set_tag("device_id", requester.device_id)
+                if requester.app_service is not None:
+                    parent_span.set_tag("appservice_id", requester.app_service.id)
+            return requester
+
+    async def _wrapped_get_user_by_req(
+        self,
+        request: SynapseRequest,
+        allow_guest: bool,
+        rights: str,
+        allow_expired: bool,
+    ) -> Requester:
+        """Helper for get_user_by_req
+
+        Once get_user_by_req has set up the opentracing span, this does the actual work.
+        """
         try:
             ip_addr = request.getClientIP()
             user_agent = get_request_user_agent(request)
 
             access_token = self.get_access_token_from_request(request)
 
-            user_id, app_service = await self._get_appservice_user_id(request)
+            (
+                user_id,
+                device_id,
+                app_service,
+            ) = await self._get_appservice_user_id_and_device_id(request)
             if user_id and app_service:
                 if ip_addr and self._track_appservice_user_ips:
                     await self.store.insert_client_ip(
@@ -163,18 +203,16 @@ class Auth:
                         access_token=access_token,
                         ip=ip_addr,
                         user_agent=user_agent,
-                        device_id="dummy-device",  # stubbed
+                        device_id="dummy-device"
+                        if device_id is None
+                        else device_id,  # stubbed
                     )
 
-                requester = create_requester(user_id, app_service=app_service)
+                requester = create_requester(
+                    user_id, app_service=app_service, device_id=device_id
+                )
 
                 request.requester = user_id
-                if user_id in self._force_tracing_for_users:
-                    opentracing.force_tracing()
-                opentracing.set_tag("authenticated_entity", user_id)
-                opentracing.set_tag("user_id", user_id)
-                opentracing.set_tag("appservice_id", app_service.id)
-
                 return requester
 
             user_info = await self.get_user_by_access_token(
@@ -232,13 +270,6 @@ class Auth:
             )
 
             request.requester = requester
-            if user_info.token_owner in self._force_tracing_for_users:
-                opentracing.force_tracing()
-            opentracing.set_tag("authenticated_entity", user_info.token_owner)
-            opentracing.set_tag("user_id", user_info.user_id)
-            if device_id:
-                opentracing.set_tag("device_id", device_id)
-
             return requester
         except KeyError:
             raise MissingClientTokenError()
@@ -274,33 +305,81 @@ class Auth:
                 403, "Application service has not registered this user (%s)" % user_id
             )
 
-    async def _get_appservice_user_id(
+    async def _get_appservice_user_id_and_device_id(
         self, request: Request
-    ) -> Tuple[Optional[str], Optional[ApplicationService]]:
+    ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]:
+        """
+        Given a request, reads the request parameters to determine:
+        - whether it's an application service that's making this request
+        - what user the application service should be treated as controlling
+          (the user_id URI parameter allows an application service to masquerade
+          as any applicable user in its namespace)
+        - what device the application service should be treated as controlling
+          (the device_id[^1] URI parameter allows an application service to masquerade
+          as any device that exists for the relevant user)
+
+        [^1] Unstable and provided by MSC3202.
+             Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
+
+        Returns:
+            3-tuple of
+            (user ID?, device ID?, application service?)
+
+        Postconditions:
+        - If an application service is returned, so is a user ID
+        - A user ID is never returned without an application service
+        - A device ID is never returned without a user ID or an application service
+        - The returned application service, if present, is permitted to control the
+          returned user ID.
+        - The returned device ID, if present, has been checked to be a valid device ID
+          for the returned user ID.
+        """
+        DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id"
+
         app_service = self.store.get_app_service_by_token(
             self.get_access_token_from_request(request)
         )
         if app_service is None:
-            return None, None
+            return None, None, None
 
         if app_service.ip_range_whitelist:
             ip_address = IPAddress(request.getClientIP())
             if ip_address not in app_service.ip_range_whitelist:
-                return None, None
+                return None, None, None
 
         # This will always be set by the time Twisted calls us.
         assert request.args is not None
 
-        if b"user_id" not in request.args:
-            return app_service.sender, app_service
+        if b"user_id" in request.args:
+            effective_user_id = request.args[b"user_id"][0].decode("utf8")
+            await self.validate_appservice_can_control_user_id(
+                app_service, effective_user_id
+            )
+        else:
+            effective_user_id = app_service.sender
 
-        user_id = request.args[b"user_id"][0].decode("utf8")
-        await self.validate_appservice_can_control_user_id(app_service, user_id)
+        effective_device_id: Optional[str] = None
 
-        if app_service.sender == user_id:
-            return app_service.sender, app_service
+        if (
+            self.hs.config.experimental.msc3202_device_masquerading_enabled
+            and DEVICE_ID_ARG_NAME in request.args
+        ):
+            effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8")
+            # We only just set this so it can't be None!
+            assert effective_device_id is not None
+            device_opt = await self.store.get_device(
+                effective_user_id, effective_device_id
+            )
+            if device_opt is None:
+                # For now, use 400 M_EXCLUSIVE if the device doesn't exist.
+                # This is an open thread of discussion on MSC3202 as of 2021-12-09.
+                raise AuthError(
+                    400,
+                    f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})",
+                    Codes.EXCLUSIVE,
+                )
 
-        return user_id, app_service
+        return effective_user_id, effective_device_id, app_service
 
     async def get_user_by_access_token(
         self,
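
With `msc3202_device_masquerading` enabled, the new `_get_appservice_user_id_and_device_id` lets an application service masquerade as a user and one of that user's existing devices via query parameters. A sketch of such a request from the appservice side (homeserver, user ID and device ID are hypothetical):

    import urllib.parse

    params = urllib.parse.urlencode(
        {
            "user_id": "@_bridge_alice:example.org",
            # Unstable parameter name from MSC3202; the device must already
            # exist for the masqueraded user, otherwise the server replies
            # 400 M_EXCLUSIVE.
            "org.matrix.msc3202.device_id": "BRIDGEDEVICE",
        }
    )
    url = f"https://example.org/_matrix/client/v3/account/whoami?{params}"
    # Send with the appservice's as_token as a Bearer token.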
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index f7d29b4319..52c083a20b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -253,5 +253,9 @@ class GuestAccess:
     FORBIDDEN: Final = "forbidden"
 
 
+class ReceiptTypes:
+    READ: Final = "m.read"
+
+
 class ReadReceiptEventFields:
     MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 13dd6ce248..d087c816db 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -351,8 +351,7 @@ class Filter:
             True if the event matches the filter.
         """
         # We usually get the full "events" as dictionaries coming through,
-        # except for presence which actually gets passed around as its own
-        # namedtuple type.
+        # except for presence which actually gets passed around as its own type.
         if isinstance(event, UserPresenceState):
             user_id = event.user_id
             field_matchers = {
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index dd76e07321..177ce040e8 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -27,6 +27,7 @@ import synapse
 import synapse.config.logger
 from synapse import events
 from synapse.api.urls import (
+    CLIENT_API_PREFIX,
     FEDERATION_PREFIX,
     LEGACY_MEDIA_PREFIX,
     MEDIA_R0_PREFIX,
@@ -192,13 +193,7 @@ class SynapseHomeServer(HomeServer):
 
             resources.update(
                 {
-                    "/_matrix/client/api/v1": client_resource,
-                    "/_matrix/client/r0": client_resource,
-                    "/_matrix/client/v1": client_resource,
-                    "/_matrix/client/v3": client_resource,
-                    "/_matrix/client/unstable": client_resource,
-                    "/_matrix/client/v2_alpha": client_resource,
-                    "/_matrix/client/versions": client_resource,
+                    CLIENT_API_PREFIX: client_resource,
                     "/.well-known": well_known_resource(self),
                     "/_synapse/admin": AdminRestResource(self),
                     **build_synapse_client_resource_tree(self),
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index f9d3bd337d..8c9ff93b2c 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -11,10 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 import re
 from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Match, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern
+
+import attr
+from netaddr import IPSet
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
@@ -33,6 +37,13 @@ class ApplicationServiceState(Enum):
     UP = "up"
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Namespace:
+    exclusive: bool
+    group_id: Optional[str]
+    regex: Pattern[str]
+
+
 class ApplicationService:
     """Defines an application service. This definition is mostly what is
     provided to the /register AS API.
@@ -50,17 +61,17 @@ class ApplicationService:
 
     def __init__(
         self,
-        token,
-        hostname,
-        id,
-        sender,
-        url=None,
-        namespaces=None,
-        hs_token=None,
-        protocols=None,
-        rate_limited=True,
-        ip_range_whitelist=None,
-        supports_ephemeral=False,
+        token: str,
+        hostname: str,
+        id: str,
+        sender: str,
+        url: Optional[str] = None,
+        namespaces: Optional[JsonDict] = None,
+        hs_token: Optional[str] = None,
+        protocols: Optional[Iterable[str]] = None,
+        rate_limited: bool = True,
+        ip_range_whitelist: Optional[IPSet] = None,
+        supports_ephemeral: bool = False,
     ):
         self.token = token
         self.url = (
@@ -85,27 +96,33 @@ class ApplicationService:
 
         self.rate_limited = rate_limited
 
-    def _check_namespaces(self, namespaces):
+    def _check_namespaces(
+        self, namespaces: Optional[JsonDict]
+    ) -> Dict[str, List[Namespace]]:
         # Sanity check that it is of the form:
         # {
         #   users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
         #   aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
         #   rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
         # }
-        if not namespaces:
+        if namespaces is None:
             namespaces = {}
 
+        result: Dict[str, List[Namespace]] = {}
+
         for ns in ApplicationService.NS_LIST:
+            result[ns] = []
+
             if ns not in namespaces:
-                namespaces[ns] = []
                 continue
 
-            if type(namespaces[ns]) != list:
+            if not isinstance(namespaces[ns], list):
                 raise ValueError("Bad namespace value for '%s'" % ns)
             for regex_obj in namespaces[ns]:
                 if not isinstance(regex_obj, dict):
                     raise ValueError("Expected dict regex for ns '%s'" % ns)
-                if not isinstance(regex_obj.get("exclusive"), bool):
+                exclusive = regex_obj.get("exclusive")
+                if not isinstance(exclusive, bool):
                     raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
                 group_id = regex_obj.get("group_id")
                 if group_id:
@@ -126,22 +143,26 @@ class ApplicationService:
                         )
 
                 regex = regex_obj.get("regex")
-                if isinstance(regex, str):
-                    regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
-                else:
+                if not isinstance(regex, str):
                     raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
-        return namespaces
 
-    def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
-        for regex_obj in self.namespaces[namespace_key]:
-            if regex_obj["regex"].match(test_string):
-                return regex_obj
+                # Pre-compile regex.
+                result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))
+
+        return result
+
+    def _matches_regex(
+        self, namespace_key: str, test_string: str
+    ) -> Optional[Namespace]:
+        for namespace in self.namespaces[namespace_key]:
+            if namespace.regex.match(test_string):
+                return namespace
         return None
 
-    def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
-        regex_obj = self._matches_regex(test_string, ns_key)
-        if regex_obj:
-            return regex_obj["exclusive"]
+    def _is_exclusive(self, namespace_key: str, test_string: str) -> bool:
+        namespace = self._matches_regex(namespace_key, test_string)
+        if namespace:
+            return namespace.exclusive
         return False
 
     async def _matches_user(
@@ -260,15 +281,15 @@ class ApplicationService:
 
     def is_interested_in_user(self, user_id: str) -> bool:
         return (
-            bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
+            bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
             or user_id == self.sender
         )
 
     def is_interested_in_alias(self, alias: str) -> bool:
-        return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
+        return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
 
     def is_interested_in_room(self, room_id: str) -> bool:
-        return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
+        return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
 
     def is_exclusive_user(self, user_id: str) -> bool:
         return (
@@ -285,14 +306,14 @@ class ApplicationService:
     def is_exclusive_room(self, room_id: str) -> bool:
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
-    def get_exclusive_user_regexes(self):
+    def get_exclusive_user_regexes(self) -> List[Pattern[str]]:
         """Get the list of regexes used to determine if a user is exclusively
         registered by the AS
         """
         return [
-            regex_obj["regex"]
-            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
-            if regex_obj["exclusive"]
+            namespace.regex
+            for namespace in self.namespaces[ApplicationService.NS_USERS]
+            if namespace.exclusive
         ]
 
     def get_groups_for_user(self, user_id: str) -> Iterable[str]:
@@ -305,15 +326,15 @@ class ApplicationService:
             An iterable that yields group_id strings.
         """
         return (
-            regex_obj["group_id"]
-            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
-            if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+            namespace.group_id
+            for namespace in self.namespaces[ApplicationService.NS_USERS]
+            if namespace.group_id and namespace.regex.match(user_id)
         )
 
     def is_rate_limited(self) -> bool:
         return self.rate_limited
 
-    def __str__(self):
+    def __str__(self) -> str:
         # copy dictionary and redact token fields so they don't get logged
         dict_copy = self.__dict__.copy()
         dict_copy["token"] = "<redacted>"
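
Application service namespaces are now parsed into frozen `Namespace` objects with a pre-compiled pattern instead of raw regex dicts, so matching becomes attribute access. A small sketch of the parsed form (the regex is made up):

    import re

    from synapse.appservice import Namespace

    ns = Namespace(
        exclusive=True,
        group_id=None,
        regex=re.compile(r"@_irc_.*:example\.org"),
    )
    assert ns.regex.match("@_irc_alice:example.org")
    assert ns.exclusive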
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index f51b636417..def4424af0 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib
-from typing import TYPE_CHECKING, List, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 from prometheus_client import Counter
 
@@ -53,7 +53,7 @@ HOUR_IN_MS = 60 * 60 * 1000
 APP_SERVICE_PREFIX = "/_matrix/app/unstable"
 
 
-def _is_valid_3pe_metadata(info):
+def _is_valid_3pe_metadata(info: JsonDict) -> bool:
     if "instances" not in info:
         return False
     if not isinstance(info["instances"], list):
@@ -61,7 +61,7 @@ def _is_valid_3pe_metadata(info):
     return True
 
 
-def _is_valid_3pe_result(r, field):
+def _is_valid_3pe_result(r: JsonDict, field: str) -> bool:
     if not isinstance(r, dict):
         return False
 
@@ -93,9 +93,13 @@ class ApplicationServiceApi(SimpleHttpClient):
             hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
         )
 
-    async def query_user(self, service, user_id):
+    async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
         if service.url is None:
             return False
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
         uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
         try:
             response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -109,9 +113,13 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_user to %s threw exception %s", uri, ex)
         return False
 
-    async def query_alias(self, service, alias):
+    async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
         if service.url is None:
             return False
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
         uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
         try:
             response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -125,7 +133,13 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_alias to %s threw exception %s", uri, ex)
         return False
 
-    async def query_3pe(self, service, kind, protocol, fields):
+    async def query_3pe(
+        self,
+        service: "ApplicationService",
+        kind: str,
+        protocol: str,
+        fields: Dict[bytes, List[bytes]],
+    ) -> List[JsonDict]:
         if kind == ThirdPartyEntityKind.USER:
             required_field = "userid"
         elif kind == ThirdPartyEntityKind.LOCATION:
@@ -205,11 +219,14 @@ class ApplicationServiceApi(SimpleHttpClient):
         events: List[EventBase],
         ephemeral: List[JsonDict],
         txn_id: Optional[int] = None,
-    ):
+    ) -> bool:
         if service.url is None:
             return True
 
-        events = self._serialize(service, events)
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
+        serialized_events = self._serialize(service, events)
 
         if txn_id is None:
             logger.warning(
@@ -221,9 +238,12 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         # Never send ephemeral events to appservices that do not support it
         if service.supports_ephemeral:
-            body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
+            body = {
+                "events": serialized_events,
+                "de.sorunome.msc2409.ephemeral": ephemeral,
+            }
         else:
-            body = {"events": events}
+            body = {"events": serialized_events}
 
         try:
             await self.put_json(
@@ -238,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                     [event.get("event_id") for event in events],
                 )
             sent_transactions_counter.labels(service.id).inc()
-            sent_events_counter.labels(service.id).inc(len(events))
+            sent_events_counter.labels(service.id).inc(len(serialized_events))
             return True
         except CodeMessageException as e:
             logger.warning(
@@ -260,7 +280,9 @@ class ApplicationServiceApi(SimpleHttpClient):
         failed_transactions_counter.labels(service.id).inc()
         return False
 
-    def _serialize(self, service, events):
+    def _serialize(
+        self, service: "ApplicationService", events: Iterable[EventBase]
+    ) -> List[JsonDict]:
         time_now = self.clock.time_msec()
         return [
             serialize_event(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 6a2ce99b55..185e3a5278 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,13 +48,19 @@ This is all tied together by the AppServiceScheduler which DIs the required
 components.
 """
 import logging
-from typing import List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.appservice.api import ApplicationServiceApi
 from synapse.events import EventBase
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main import DataStore
 from synapse.types import JsonDict
+from synapse.util import Clock
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -72,7 +78,7 @@ class ApplicationServiceScheduler:
     case is a simple array.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.as_api = hs.get_application_service_api()
@@ -80,7 +86,7 @@ class ApplicationServiceScheduler:
         self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
         self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
 
-    async def start(self):
+    async def start(self) -> None:
         logger.info("Starting appservice scheduler")
 
         # check for any DOWN ASes and start recoverers for them.
@@ -91,12 +97,14 @@ class ApplicationServiceScheduler:
         for service in services:
             self.txn_ctrl.start_recoverer(service)
 
-    def submit_event_for_as(self, service: ApplicationService, event: EventBase):
+    def submit_event_for_as(
+        self, service: ApplicationService, event: EventBase
+    ) -> None:
         self.queuer.enqueue_event(service, event)
 
     def submit_ephemeral_events_for_as(
         self, service: ApplicationService, events: List[JsonDict]
-    ):
+    ) -> None:
         self.queuer.enqueue_ephemeral(service, events)
 
 
@@ -108,16 +116,18 @@ class _ServiceQueuer:
     appservice at a given time.
     """
 
-    def __init__(self, txn_ctrl, clock):
-        self.queued_events = {}  # dict of {service_id: [events]}
-        self.queued_ephemeral = {}  # dict of {service_id: [events]}
+    def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
+        # dict of {service_id: [events]}
+        self.queued_events: Dict[str, List[EventBase]] = {}
+        # dict of {service_id: [events]}
+        self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
 
         # the appservices which currently have a transaction in flight
-        self.requests_in_flight = set()
+        self.requests_in_flight: Set[str] = set()
         self.txn_ctrl = txn_ctrl
         self.clock = clock
 
-    def _start_background_request(self, service):
+    def _start_background_request(self, service: ApplicationService) -> None:
         # start a sender for this appservice if we don't already have one
         if service.id in self.requests_in_flight:
             return
@@ -126,15 +136,17 @@ class _ServiceQueuer:
             "as-sender-%s" % (service.id,), self._send_request, service
         )
 
-    def enqueue_event(self, service: ApplicationService, event: EventBase):
+    def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
         self.queued_events.setdefault(service.id, []).append(event)
         self._start_background_request(service)
 
-    def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
+    def enqueue_ephemeral(
+        self, service: ApplicationService, events: List[JsonDict]
+    ) -> None:
         self.queued_ephemeral.setdefault(service.id, []).extend(events)
         self._start_background_request(service)
 
-    async def _send_request(self, service: ApplicationService):
+    async def _send_request(self, service: ApplicationService) -> None:
         # sanity-check: we shouldn't get here if this service already has a sender
         # running.
         assert service.id not in self.requests_in_flight
@@ -168,20 +180,15 @@ class _TransactionController:
     if a transaction fails.
 
     (Note we only have one of these in the homeserver.)
-
-    Args:
-        clock (synapse.util.Clock):
-        store (synapse.storage.DataStore):
-        as_api (synapse.appservice.api.ApplicationServiceApi):
     """
 
-    def __init__(self, clock, store, as_api):
+    def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi):
         self.clock = clock
         self.store = store
         self.as_api = as_api
 
         # map from service id to recoverer instance
-        self.recoverers = {}
+        self.recoverers: Dict[str, "_Recoverer"] = {}
 
         # for UTs
         self.RECOVERER_CLASS = _Recoverer
@@ -191,7 +198,7 @@ class _TransactionController:
         service: ApplicationService,
         events: List[EventBase],
         ephemeral: Optional[List[JsonDict]] = None,
-    ):
+    ) -> None:
         try:
             txn = await self.store.create_appservice_txn(
                 service=service, events=events, ephemeral=ephemeral or []
@@ -207,7 +214,7 @@ class _TransactionController:
             logger.exception("Error creating appservice transaction")
             run_in_background(self._on_txn_fail, service)
 
-    async def on_recovered(self, recoverer):
+    async def on_recovered(self, recoverer: "_Recoverer") -> None:
         logger.info(
             "Successfully recovered application service AS ID %s", recoverer.service.id
         )
@@ -217,18 +224,18 @@ class _TransactionController:
             recoverer.service, ApplicationServiceState.UP
         )
 
-    async def _on_txn_fail(self, service):
+    async def _on_txn_fail(self, service: ApplicationService) -> None:
         try:
             await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
             self.start_recoverer(service)
         except Exception:
             logger.exception("Error starting AS recoverer")
 
-    def start_recoverer(self, service):
+    def start_recoverer(self, service: ApplicationService) -> None:
         """Start a Recoverer for the given service
 
         Args:
-            service (synapse.appservice.ApplicationService):
+            service:
         """
         logger.info("Starting recoverer for AS ID %s", service.id)
         assert service.id not in self.recoverers
@@ -257,7 +264,14 @@ class _Recoverer:
         callback (callable[_Recoverer]): called once the service recovers.
     """
 
-    def __init__(self, clock, store, as_api, service, callback):
+    def __init__(
+        self,
+        clock: Clock,
+        store: DataStore,
+        as_api: ApplicationServiceApi,
+        service: ApplicationService,
+        callback: Callable[["_Recoverer"], Awaitable[None]],
+    ):
         self.clock = clock
         self.store = store
         self.as_api = as_api
@@ -265,8 +279,8 @@ class _Recoverer:
         self.callback = callback
         self.backoff_counter = 1
 
-    def recover(self):
-        def _retry():
+    def recover(self) -> None:
+        def _retry() -> None:
             run_as_background_process(
                 "as-recoverer-%s" % (self.service.id,), self.retry
             )
@@ -275,13 +289,13 @@ class _Recoverer:
         logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
         self.clock.call_later(delay, _retry)
 
-    def _backoff(self):
+    def _backoff(self) -> None:
         # cap the backoff to be around 8.5min => (2^9) = 512 secs
         if self.backoff_counter < 9:
             self.backoff_counter += 1
         self.recover()
 
-    async def retry(self):
+    async def retry(self) -> None:
         logger.info("Starting retries on %s", self.service.id)
         try:
             while True:
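
The recoverer's retry delay is driven by `backoff_counter`, which starts at 1 and which `_backoff` caps at 9 (roughly 8.5 minutes). Assuming the delay is `2 ** backoff_counter` seconds, as the "(2^9) = 512 secs" comment suggests (the formula itself is outside this hunk), the successive delays look like:

    def backoff_delays(max_counter: int = 9) -> list:
        # counter starts at 1 and is bumped after each failed retry, up to the cap
        return [2 ** counter for counter in range(1, max_counter + 1)]

    print(backoff_delays())  # [2, 4, 8, 16, 32, 64, 128, 256, 512]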
diff --git a/synapse/config/api.py b/synapse/config/api.py
index b18044f982..25538b82d5 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -107,6 +107,8 @@ _DEFAULT_PREJOIN_STATE_TYPES = [
     EventTypes.Name,
     # Per MSC1772.
     EventTypes.Create,
+    # Per MSC3173.
+    EventTypes.Topic,
 ]
 
 
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index e4bb7224a4..7fad2e0422 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -147,8 +147,7 @@ def _load_appservice(
     # protocols check
     protocols = as_info.get("protocols")
     if protocols:
-        # Because strings are lists in python
-        if isinstance(protocols, str) or not isinstance(protocols, list):
+        if not isinstance(protocols, list):
             raise KeyError("Optional 'protocols' must be a list if present.")
         for p in protocols:
             if not isinstance(p, str):
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index d78a15097c..dbaeb10918 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -32,7 +32,7 @@ class ExperimentalConfig(Config):
         # MSC3026 (busy presence state)
         self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
 
-        # MSC2716 (backfill existing history)
+        # MSC2716 (importing historical messages)
         self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
 
         # MSC2285 (hidden read receipts)
@@ -49,3 +49,8 @@ class ExperimentalConfig(Config):
 
         # MSC3030 (Jump to date API endpoint)
         self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False)
+
+        # The portion of MSC3202 which is related to device masquerading.
+        self.msc3202_device_masquerading_enabled: bool = experimental.get(
+            "msc3202_device_masquerading", False
+        )
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 035ee2416b..ee83c6c06b 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -16,12 +16,14 @@
 import hashlib
 import logging
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Iterator, List, Optional
 
 import attr
 import jsonschema
 from signedjson.key import (
     NACL_ED25519,
+    SigningKey,
+    VerifyKey,
     decode_signing_key_base64,
     decode_verify_key_bytes,
     generate_signing_key,
@@ -31,6 +33,7 @@ from signedjson.key import (
 )
 from unpaddedbase64 import decode_base64
 
+from synapse.types import JsonDict
 from synapse.util.stringutils import random_string, random_string_with_symbols
 
 from ._base import Config, ConfigError
@@ -81,14 +84,13 @@ To suppress this warning and continue using 'matrix.org', admins should set
 logger = logging.getLogger(__name__)
 
 
-@attr.s
+@attr.s(slots=True, auto_attribs=True)
 class TrustedKeyServer:
-    # string: name of the server.
-    server_name = attr.ib()
+    # name of the server.
+    server_name: str
 
-    # dict[str,VerifyKey]|None: map from key id to key object, or None to disable
-    # signature verification.
-    verify_keys = attr.ib(default=None)
+    # map from key id to key object, or None to disable signature verification.
+    verify_keys: Optional[Dict[str, VerifyKey]] = None
 
 
 class KeyConfig(Config):
@@ -279,15 +281,15 @@ class KeyConfig(Config):
             % locals()
         )
 
-    def read_signing_keys(self, signing_key_path, name):
+    def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
         """Read the signing keys in the given path.
 
         Args:
-            signing_key_path (str)
-            name (str): Associated config key name
+            signing_key_path
+            name: Associated config key name
 
         Returns:
-            list[SigningKey]
+            The signing keys read from the given path.
         """
 
         signing_keys = self.read_file(signing_key_path, name)
@@ -296,7 +298,9 @@ class KeyConfig(Config):
         except Exception as e:
             raise ConfigError("Error reading %s: %s" % (name, str(e)))
 
-    def read_old_signing_keys(self, old_signing_keys):
+    def read_old_signing_keys(
+        self, old_signing_keys: Optional[JsonDict]
+    ) -> Dict[str, VerifyKey]:
         if old_signing_keys is None:
             return {}
         keys = {}
@@ -340,7 +344,7 @@ class KeyConfig(Config):
                     write_signing_keys(signing_key_file, (key,))
 
 
-def _perspectives_to_key_servers(config):
+def _perspectives_to_key_servers(config: JsonDict) -> Iterator[JsonDict]:
     """Convert old-style 'perspectives' configs into new-style 'trusted_key_servers'
 
     Returns an iterable of entries to add to trusted_key_servers.
@@ -402,7 +406,9 @@ TRUSTED_KEY_SERVERS_SCHEMA = {
 }
 
 
-def _parse_key_servers(key_servers, federation_verify_certificates):
+def _parse_key_servers(
+    key_servers: List[Any], federation_verify_certificates: bool
+) -> Iterator[TrustedKeyServer]:
     try:
         jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA)
     except jsonschema.ValidationError as e:
@@ -444,7 +450,7 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
         yield result
 
 
-def _assert_keyserver_has_verify_keys(trusted_key_server):
+def _assert_keyserver_has_verify_keys(trusted_key_server: TrustedKeyServer) -> None:
     if not trusted_key_server.verify_keys:
         raise ConfigError(INSECURE_NOTARY_ERROR)
 
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 7ac82edb0e..1cc26e7578 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -22,10 +22,12 @@ from ._base import Config, ConfigError
 
 @attr.s
 class MetricsFlags:
-    known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool))
+    known_servers: bool = attr.ib(
+        default=False, validator=attr.validators.instance_of(bool)
+    )
 
     @classmethod
-    def all_off(cls):
+    def all_off(cls) -> "MetricsFlags":
         """
         Instantiate the flags with all options set to off.
         """
diff --git a/synapse/config/modules.py b/synapse/config/modules.py
index ae0821e5a5..85fb05890d 100644
--- a/synapse/config/modules.py
+++ b/synapse/config/modules.py
@@ -37,7 +37,7 @@ class ModulesConfig(Config):
 
             # Server admins can expand Synapse's functionality with external modules.
             #
-            # See https://matrix-org.github.io/synapse/latest/modules.html for more
+            # See https://matrix-org.github.io/synapse/latest/modules/index.html for more
             # documentation on how to configure or create custom modules for Synapse.
             #
             modules:
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index b129b9dd68..1980351e77 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -14,10 +14,11 @@
 
 import logging
 import os
-from collections import namedtuple
 from typing import Dict, List, Tuple
 from urllib.request import getproxies_environment  # type: ignore
 
+import attr
+
 from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
 from synapse.python_dependencies import DependencyException, check_requirements
 from synapse.types import JsonDict
@@ -44,18 +45,20 @@ THUMBNAIL_SIZE_YAML = """\
 HTTP_PROXY_SET_WARNING = """\
 The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
 
-ThumbnailRequirement = namedtuple(
-    "ThumbnailRequirement", ["width", "height", "method", "media_type"]
-)
 
-MediaStorageProviderConfig = namedtuple(
-    "MediaStorageProviderConfig",
-    (
-        "store_local",  # Whether to store newly uploaded local files
-        "store_remote",  # Whether to store newly downloaded remote files
-        "store_synchronous",  # Whether to wait for successful storage for local uploads
-    ),
-)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThumbnailRequirement:
+    width: int
+    height: int
+    method: str
+    media_type: str
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class MediaStorageProviderConfig:
+    store_local: bool  # Whether to store newly uploaded local files
+    store_remote: bool  # Whether to store newly downloaded remote files
+    store_synchronous: bool  # Whether to wait for successful storage for local uploads
 
 
 def parse_thumbnail_requirements(
@@ -66,11 +69,10 @@ def parse_thumbnail_requirements(
     method, and thumbnail media type to precalculate
 
     Args:
-        thumbnail_sizes(list): List of dicts with "width", "height", and
-            "method" keys
+        thumbnail_sizes: List of dicts with "width", "height", and "method" keys
+
     Returns:
-        Dictionary mapping from media type string to list of
-        ThumbnailRequirement tuples.
+        Dictionary mapping from media type string to list of ThumbnailRequirement.
     """
     requirements: Dict[str, List[ThumbnailRequirement]] = {}
     for size in thumbnail_sizes:
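
`ThumbnailRequirement` and `MediaStorageProviderConfig` are now frozen attrs classes rather than namedtuples, so their fields are typed and instances are immutable. A quick sketch with made-up values:

    from synapse.config.repository import (
        MediaStorageProviderConfig,
        ThumbnailRequirement,
    )

    req = ThumbnailRequirement(width=320, height=240, method="scale", media_type="image/jpeg")
    cfg = MediaStorageProviderConfig(store_local=True, store_remote=False, store_synchronous=True)
    assert req.width == 320 and not cfg.store_remote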
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 57316c59b6..3c5e0f7ce7 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -15,8 +15,9 @@
 
 from typing import List
 
+from matrix_common.regex import glob_to_regex
+
 from synapse.types import JsonDict
-from synapse.util import glob_to_regex
 
 from ._base import Config, ConfigError
 
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ba5b954263..1de2dea9b0 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1257,7 +1257,7 @@ class ServerConfig(Config):
             help="Turn on the twisted telnet manhole service on the given port.",
         )
 
-    def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]:
+    def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]:
         """Reads the three durations for the GC min interval option, returning seconds."""
         if durations is None:
             return None
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 4ca111618f..6e673d65a7 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -16,11 +16,12 @@ import logging
 import os
 from typing import List, Optional, Pattern
 
+from matrix_common.regex import glob_to_regex
+
 from OpenSSL import SSL, crypto
 from twisted.internet._sslverify import Certificate, trustRootFromCertificates
 
 from synapse.config._base import Config, ConfigError
-from synapse.util import glob_to_regex
 
 logger = logging.getLogger(__name__)
 
@@ -132,7 +133,7 @@ class TlsConfig(Config):
         self.tls_certificate: Optional[crypto.X509] = None
         self.tls_private_key: Optional[crypto.PKey] = None
 
-    def read_certificate_from_disk(self):
+    def read_certificate_from_disk(self) -> None:
         """
         Read the certificates and private key from disk.
         """
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 84ef69df67..2038e72924 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -395,7 +395,7 @@ class EventClientSerializer:
         event: Union[JsonDict, EventBase],
         time_now: int,
         *,
-        bundle_aggregations: bool = True,
+        bundle_aggregations: bool = False,
         **kwargs: Any,
     ) -> JsonDict:
         """Serializes a single event.
@@ -454,23 +454,26 @@ class EventClientSerializer:
                 return
 
         event_id = event.event_id
+        room_id = event.room_id
 
         # The bundled aggregations to include.
         aggregations = {}
 
-        annotations = await self.store.get_aggregation_groups_for_event(event_id)
+        annotations = await self.store.get_aggregation_groups_for_event(
+            event_id, room_id
+        )
         if annotations.chunk:
             aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
 
         references = await self.store.get_relations_for_event(
-            event_id, RelationTypes.REFERENCE, direction="f"
+            event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
             aggregations[RelationTypes.REFERENCE] = references.to_dict()
 
         edit = None
         if event.type == EventTypes.Message:
-            edit = await self.store.get_applicable_edit(event_id)
+            edit = await self.store.get_applicable_edit(event_id, room_id)
 
         if edit:
             # If there is an edit replace the content, preserving existing
@@ -503,7 +506,7 @@ class EventClientSerializer:
             (
                 thread_count,
                 latest_thread_event,
-            ) = await self.store.get_thread_summary(event_id)
+            ) = await self.store.get_thread_summary(event_id, room_id)
             if latest_thread_event:
                 aggregations[RelationTypes.THREAD] = {
                     # Don't bundle aggregations as this could recurse forever.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index f56344a3b9..addc0bf000 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from collections import namedtuple
 from typing import TYPE_CHECKING
 
 from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
@@ -104,10 +103,6 @@ class FederationBase:
         return pdu
 
 
-class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
-    pass
-
-
 async def _check_sigs_on_pdu(
     keyring: Keyring, room_version: RoomVersion, pdu: EventBase
 ) -> None:
@@ -220,15 +215,12 @@ def _is_invite_via_3pid(event: EventBase) -> bool:
     )
 
 
-def event_from_pdu_json(
-    pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False
-) -> EventBase:
+def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventBase:
     """Construct an EventBase from an event json received over federation
 
     Args:
         pdu_json: pdu as received over federation
         room_version: The version of the room this event belongs to
-        outlier: True to mark this event as an outlier
 
     Raises:
         SynapseError: if the pdu is missing required fields or is otherwise
@@ -252,6 +244,4 @@ def event_from_pdu_json(
         validate_canonicaljson(pdu_json)
 
     event = make_event_from_dict(pdu_json, room_version)
-    event.internal_metadata.outlier = outlier
-
     return event
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index fee1477ab6..6ea4edfc71 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -265,14 +265,11 @@ class FederationClient(FederationBase):
 
         room_version = await self.store.get_room_version(room_id)
 
-        pdus = [
-            event_from_pdu_json(p, room_version, outlier=False)
-            for p in transaction_data_pdus
-        ]
+        pdus = [event_from_pdu_json(p, room_version) for p in transaction_data_pdus]
 
         # Check signatures and hash of pdus, removing any from the list that fail checks
         pdus[:] = await self._check_sigs_and_hash_and_fetch(
-            dest, pdus, outlier=True, room_version=room_version
+            dest, pdus, room_version=room_version
         )
 
         return pdus
@@ -282,7 +279,6 @@ class FederationClient(FederationBase):
         destination: str,
         event_id: str,
         room_version: RoomVersion,
-        outlier: bool = False,
         timeout: Optional[int] = None,
     ) -> Optional[EventBase]:
         """Requests the PDU with given origin and ID from the remote home
@@ -292,9 +288,6 @@ class FederationClient(FederationBase):
             destination: Which homeserver to query
             event_id: event to fetch
             room_version: version of the room
-            outlier: Indicates whether the PDU is an `outlier`, i.e. if
-                it's from an arbitrary point in the context as opposed to part
-                of the current block of PDUs. Defaults to `False`
             timeout: How long to try (in ms) each destination for before
                 moving to the next destination. None indicates no timeout.
 
@@ -316,8 +309,7 @@ class FederationClient(FederationBase):
         )
 
         pdu_list: List[EventBase] = [
-            event_from_pdu_json(p, room_version, outlier=outlier)
-            for p in transaction_data["pdus"]
+            event_from_pdu_json(p, room_version) for p in transaction_data["pdus"]
         ]
 
         if pdu_list and pdu_list[0]:
@@ -334,7 +326,6 @@ class FederationClient(FederationBase):
         destinations: Iterable[str],
         event_id: str,
         room_version: RoomVersion,
-        outlier: bool = False,
         timeout: Optional[int] = None,
     ) -> Optional[EventBase]:
         """Requests the PDU with given origin and ID from the remote home
@@ -347,9 +338,6 @@ class FederationClient(FederationBase):
             destinations: Which homeservers to query
             event_id: event to fetch
             room_version: version of the room
-            outlier: Indicates whether the PDU is an `outlier`, i.e. if
-                it's from an arbitrary point in the context as opposed to part
-                of the current block of PDUs. Defaults to `False`
             timeout: How long to try (in ms) each destination for before
                 moving to the next destination. None indicates no timeout.
 
@@ -377,7 +365,6 @@ class FederationClient(FederationBase):
                     destination=destination,
                     event_id=event_id,
                     room_version=room_version,
-                    outlier=outlier,
                     timeout=timeout,
                 )
 
@@ -435,7 +422,6 @@ class FederationClient(FederationBase):
         origin: str,
         pdus: Collection[EventBase],
         room_version: RoomVersion,
-        outlier: bool = False,
     ) -> List[EventBase]:
         """Takes a list of PDUs and checks the signatures and hashes of each
         one. If a PDU fails its signature check then we check if we have it in
@@ -451,7 +437,6 @@ class FederationClient(FederationBase):
             origin
             pdu
             room_version
-            outlier: Whether the events are outliers or not
 
         Returns:
             A list of PDUs that have valid signatures and hashes.
@@ -466,7 +451,6 @@ class FederationClient(FederationBase):
             valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
                 pdu=pdu,
                 origin=origin,
-                outlier=outlier,
                 room_version=room_version,
             )
 
@@ -482,7 +466,6 @@ class FederationClient(FederationBase):
         pdu: EventBase,
         origin: str,
         room_version: RoomVersion,
-        outlier: bool = False,
     ) -> Optional[EventBase]:
         """Takes a PDU and checks its signatures and hashes. If the PDU fails
         its signature check then we check if we have it in the database and if
@@ -494,9 +477,6 @@ class FederationClient(FederationBase):
             origin
             pdu
             room_version
-            outlier: Whether the events are outliers or not
-            include_none: Whether to include None in the returned list
-                for events that have failed their checks
 
         Returns:
             The PDU (possibly redacted) if it has valid signatures and hashes.
@@ -521,7 +501,6 @@ class FederationClient(FederationBase):
                     destinations=[pdu_origin],
                     event_id=pdu.event_id,
                     room_version=room_version,
-                    outlier=outlier,
                     timeout=10000,
                 )
             except SynapseError:
@@ -541,13 +520,10 @@ class FederationClient(FederationBase):
 
         room_version = await self.store.get_room_version(room_id)
 
-        auth_chain = [
-            event_from_pdu_json(p, room_version, outlier=True)
-            for p in res["auth_chain"]
-        ]
+        auth_chain = [event_from_pdu_json(p, room_version) for p in res["auth_chain"]]
 
         signed_auth = await self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain, outlier=True, room_version=room_version
+            destination, auth_chain, room_version=room_version
         )
 
         return signed_auth
@@ -816,7 +792,6 @@ class FederationClient(FederationBase):
                 valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
                     pdu=event,
                     origin=destination,
-                    outlier=True,
                     room_version=room_version,
                 )
 
@@ -864,7 +839,6 @@ class FederationClient(FederationBase):
                 valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
                     pdu=pdu,
                     origin=destination,
-                    outlier=True,
                     room_version=room_version,
                 )
 
@@ -1235,7 +1209,7 @@ class FederationClient(FederationBase):
             ]
 
             signed_events = await self._check_sigs_and_hash_and_fetch(
-                destination, events, outlier=False, room_version=room_version
+                destination, events, room_version=room_version
             )
         except HttpResponseException as e:
             if not e.code == 400:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 8e37e76206..ee71f289c8 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -28,9 +28,9 @@ from typing import (
     Union,
 )
 
+from matrix_common.regex import glob_to_regex
 from prometheus_client import Counter, Gauge, Histogram
 
-from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
 from twisted.python import failure
 
@@ -66,8 +66,8 @@ from synapse.replication.http.federation import (
 )
 from synapse.storage.databases.main.lock import Lock
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util import json_decoder, unwrapFirstError
+from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import parse_server_name
 
@@ -360,13 +360,13 @@ class FederationServer(FederationBase):
         # want to block things like to device messages from reaching clients
         # behind the potentially expensive handling of PDUs.
         pdu_results, _ = await make_deferred_yieldable(
-            defer.gatherResults(
-                [
+            gather_results(
+                (
                     run_in_background(
                         self._handle_pdus_in_txn, origin, transaction, request_time
                     ),
                     run_in_background(self._handle_edus_in_txn, origin, transaction),
-                ],
+                ),
                 consumeErrors=True,
             ).addErrback(unwrapFirstError)
         )
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 63289a5a33..0d7c4f5067 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -30,7 +30,6 @@ Events are replicated via a separate events stream.
 """
 
 import logging
-from collections import namedtuple
 from typing import (
     TYPE_CHECKING,
     Dict,
@@ -43,6 +42,7 @@ from typing import (
     Type,
 )
 
+import attr
 from sortedcontainers import SortedDict
 
 from synapse.api.presence import UserPresenceState
@@ -382,13 +382,11 @@ class BaseFederationRow:
         raise NotImplementedError()
 
 
-class PresenceDestinationsRow(
-    BaseFederationRow,
-    namedtuple(
-        "PresenceDestinationsRow",
-        ("state", "destinations"),  # UserPresenceState  # list[str]
-    ),
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PresenceDestinationsRow(BaseFederationRow):
+    state: UserPresenceState
+    destinations: List[str]
+
     TypeId = "pd"
 
     @staticmethod
@@ -404,17 +402,15 @@ class PresenceDestinationsRow(
         buff.presence_destinations.append((self.state, self.destinations))
 
 
-class KeyedEduRow(
-    BaseFederationRow,
-    namedtuple(
-        "KeyedEduRow",
-        ("key", "edu"),  # tuple(str) - the edu key passed to send_edu  # Edu
-    ),
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class KeyedEduRow(BaseFederationRow):
     """Streams EDUs that have an associated key that is ued to clobber. For example,
     typing EDUs clobber based on room_id.
     """
 
+    key: Tuple[str, ...]  # the edu key passed to send_edu
+    edu: Edu
+
     TypeId = "k"
 
     @staticmethod
@@ -428,9 +424,12 @@ class KeyedEduRow(
         buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
 
 
-class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EduRow(BaseFederationRow):
     """Streams EDUs that don't have keys. See KeyedEduRow"""
 
+    edu: Edu
+
     TypeId = "e"
 
     @staticmethod
@@ -453,14 +452,14 @@ _rowtypes: Tuple[Type[BaseFederationRow], ...] = (
 TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
 
 
-ParsedFederationStreamData = namedtuple(
-    "ParsedFederationStreamData",
-    (
-        "presence_destinations",  # list of tuples of UserPresenceState and destinations
-        "keyed_edus",  # dict of destination -> { key -> Edu }
-        "edus",  # dict of destination -> [Edu]
-    ),
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ParsedFederationStreamData:
+    # list of tuples of UserPresenceState and destinations
+    presence_destinations: List[Tuple[UserPresenceState, List[str]]]
+    # dict of destination -> { key -> Edu }
+    keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]]
+    # dict of destination -> [Edu]
+    edus: Dict[str, List[Edu]]
 
 
 def process_rows_for_federation(
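
The rows above follow a pattern repeated throughout this changeset: replacing `namedtuple` definitions with frozen, slotted `attrs` classes so the fields carry real type annotations instead of comments. A minimal, standalone before/after illustration (hypothetical `ExampleRow`, not one of the rows above):

    from collections import namedtuple
    from typing import List

    import attr

    # Before: field types only live in comments, and type checkers see them as Any.
    OldExampleRow = namedtuple("OldExampleRow", ("state", "destinations"))


    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class ExampleRow:
        """After: the same data, but with checkable annotations and immutability."""

        state: str
        destinations: List[str]


    row = ExampleRow(state="online", destinations=["example.org"])
    # Frozen attrs classes reject mutation, much like namedtuples:
    # row.state = "offline"  # would raise attr.exceptions.FrozenInstanceError
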
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index dc39e3537b..da1fbf8b63 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -22,13 +22,11 @@ from synapse.api.urls import FEDERATION_V1_PREFIX
 from synapse.http.server import HttpServer, ServletCallback
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.logging import opentracing
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
-    SynapseTags,
-    start_active_span,
-    start_active_span_from_request,
-    tags,
+    set_tag,
+    span_context_from_request,
+    start_active_span_follows_from,
     whitelisted_homeserver,
 )
 from synapse.server import HomeServer
@@ -279,30 +277,19 @@ class BaseFederationServlet:
                 logger.warning("authenticate_request failed: %s", e)
                 raise
 
-            request_tags = {
-                SynapseTags.REQUEST_ID: request.get_request_id(),
-                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
-                tags.HTTP_METHOD: request.get_method(),
-                tags.HTTP_URL: request.get_redacted_uri(),
-                tags.PEER_HOST_IPV6: request.getClientIP(),
-                "authenticated_entity": origin,
-                "servlet_name": request.request_metrics.name,
-            }
-
-            # Only accept the span context if the origin is authenticated
-            # and whitelisted
+            # update the active opentracing span with the authenticated entity
+            set_tag("authenticated_entity", origin)
+
+            # if the origin is authenticated and whitelisted, link to its span context
+            context = None
             if origin and whitelisted_homeserver(origin):
-                scope = start_active_span_from_request(
-                    request, "incoming-federation-request", tags=request_tags
-                )
-            else:
-                scope = start_active_span(
-                    "incoming-federation-request", tags=request_tags
-                )
+                context = span_context_from_request(request)
 
-            with scope:
-                opentracing.inject_response_headers(request.responseHeaders)
+            scope = start_active_span_follows_from(
+                "incoming-federation-request", contexts=(context,) if context else ()
+            )
 
+            with scope:
                 if origin and self.RATELIMIT:
                     with ratelimiter.ratelimit(origin) as d:
                         await d
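
The servlet now extracts a span context from the incoming request and starts a local span that *follows from* it, rather than making the remote span the parent. In generic `opentracing` terms (a hedged sketch using the plain opentracing-python API and its MockTracer, not Synapse's wrappers), the distinction looks roughly like this:

    import opentracing
    from opentracing.mocktracer import MockTracer

    tracer = MockTracer()

    # Stand-in for the remote caller's span; in the servlet above its context is
    # recovered from the incoming request's headers instead.
    with tracer.start_active_span("remote-caller") as remote_scope:
        remote_context = remote_scope.span.context

    # A FOLLOWS_FROM reference links the two traces without making the remote
    # span the parent of the local one.
    with tracer.start_active_span(
        "incoming-federation-request",
        references=[opentracing.follows_from(remote_context)],
    ):
        pass  # handle the request here
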
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9abdad262b..7833e77e2b 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -462,9 +462,9 @@ class ApplicationServicesHandler:
 
         Args:
             room_alias: The room alias to query.
+
         Returns:
-            namedtuple: with keys "room_id" and "servers" or None if no
-            association can be found.
+            RoomAliasMapping or None if no association can be found.
         """
         room_alias_str = room_alias.to_string()
         services = self.store.get_app_services()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 61607cf2ba..84724b207c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -997,9 +997,7 @@ class AuthHandler:
         # really don't want is active access_tokens without a record of the
         # device, so we double-check it here.
         if device_id is not None:
-            try:
-                await self.store.get_device(user_id, device_id)
-            except StoreError:
+            if await self.store.get_device(user_id, device_id) is None:
                 await self.store.delete_access_token(access_token)
                 raise StoreError(400, "Login raced against device deletion")
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 82ee11e921..7665425232 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -106,10 +106,10 @@ class DeviceWorkerHandler:
         Raises:
             errors.NotFoundError: if the device was not found
         """
-        try:
-            device = await self.store.get_device(user_id, device_id)
-        except errors.StoreError:
-            raise errors.NotFoundError
+        device = await self.store.get_device(user_id, device_id)
+        if device is None:
+            raise errors.NotFoundError()
+
         ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
         _update_device_from_client_ips(device, ips)
 
@@ -602,6 +602,8 @@ class DeviceHandler(DeviceWorkerHandler):
             access_token, device_id
         )
         old_device = await self.store.get_device(user_id, old_device_id)
+        if old_device is None:
+            raise errors.NotFoundError()
         await self.store.update_device(user_id, device_id, old_device["display_name"])
         # can't call self.delete_device because that will clobber the
         # access token so call the storage layer directly
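
Both hunks above reflect `store.get_device` now returning `None` for a missing device instead of raising `StoreError`, so callers check for `None` explicitly and raise `NotFoundError` themselves. A generic sketch of that calling convention (hypothetical `get_thing`, not the Synapse storage API):

    from typing import Dict, Optional


    class NotFoundError(Exception):
        """Raised when the caller asked for something that does not exist."""


    _THINGS: Dict[str, Dict[str, str]] = {"abc": {"display_name": "my phone"}}


    def get_thing(thing_id: str) -> Optional[Dict[str, str]]:
        # Returning None pushes the "does it exist?" decision to the caller,
        # instead of every caller having to catch a storage-layer exception.
        return _THINGS.get(thing_id)


    def describe_thing(thing_id: str) -> str:
        thing = get_thing(thing_id)
        if thing is None:
            raise NotFoundError(thing_id)
        return thing["display_name"]
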
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 7ee5c47fd9..082f521791 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -278,13 +278,15 @@ class DirectoryHandler:
 
         users = await self.store.get_users_in_room(room_id)
         extra_servers = {get_domain_from_id(u) for u in users}
-        servers = set(extra_servers) | set(servers)
+        servers_set = set(extra_servers) | set(servers)
 
         # If this server is in the list of servers, return it first.
-        if self.server_name in servers:
-            servers = [self.server_name] + [s for s in servers if s != self.server_name]
+        if self.server_name in servers_set:
+            servers = [self.server_name] + [
+                s for s in servers_set if s != self.server_name
+            ]
         else:
-            servers = list(servers)
+            servers = list(servers_set)
 
         return {"room_id": room_id, "servers": servers}
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 60c11e3d21..14360b4e40 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -65,8 +65,12 @@ class E2eKeysHandler:
         else:
             # Only register this edu handler on master as it requires writing
             # device updates to the db
-            #
-            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            federation_registry.register_edu_handler(
+                "m.signing_key_update",
+                self._edu_updater.incoming_signing_key_update,
+            )
+            # also handle the unstable version
+            # FIXME: remove this when enough servers have upgraded
             federation_registry.register_edu_handler(
                 "org.matrix.signing_key_update",
                 self._edu_updater.incoming_signing_key_update,
@@ -576,7 +580,9 @@ class E2eKeysHandler:
             log_kv(
                 {"message": "Did not update one_time_keys", "reason": "no keys given"}
             )
-        fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
+        fallback_keys = keys.get("fallback_keys") or keys.get(
+            "org.matrix.msc2732.fallback_keys"
+        )
         if fallback_keys and isinstance(fallback_keys, dict):
             log_kv(
                 {
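
The second hunk above accepts the stable `fallback_keys` field while still honouring the `org.matrix.msc2732.`-prefixed name that older clients send. A small, hedged sketch of that "stable name first, unstable prefix second" lookup pattern:

    from typing import Any, Dict, Optional


    def get_fallback_keys(keys: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Prefer the stable field name, falling back to the MSC-prefixed one."""
        return keys.get("fallback_keys") or keys.get("org.matrix.msc2732.fallback_keys")


    # An older client uploading only the unstable field is still understood:
    print(get_fallback_keys({"org.matrix.msc2732.fallback_keys": {"signed_curve25519:AAA": {}}}))
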
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 31742236a9..12614b2c5d 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -14,7 +14,9 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Dict, Optional
+
+from typing_extensions import Literal
 
 from synapse.api.errors import (
     Codes,
@@ -24,6 +26,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, trace
+from synapse.storage.databases.main.e2e_room_keys import RoomKey
 from synapse.types import JsonDict
 from synapse.util.async_helpers import Linearizer
 
@@ -58,7 +61,9 @@ class E2eRoomKeysHandler:
         version: str,
         room_id: Optional[str] = None,
         session_id: Optional[str] = None,
-    ) -> List[JsonDict]:
+    ) -> Dict[
+        Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+    ]:
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
         See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@@ -72,8 +77,8 @@ class E2eRoomKeysHandler:
         Raises:
             NotFoundError: if the backup version does not exist
         Returns:
-            A list of dicts giving the session_data and message metadata for
-            these room keys.
+            A dict giving the session_data and message metadata for these room keys.
+            `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
         """
 
         # we deliberately take the lock to get keys so that changing the version
@@ -273,7 +278,7 @@ class E2eRoomKeysHandler:
 
     @staticmethod
     def _should_replace_room_key(
-        current_room_key: Optional[JsonDict], room_key: JsonDict
+        current_room_key: Optional[RoomKey], room_key: RoomKey
     ) -> bool:
         """
         Determine whether to replace a given current_room_key (if any)
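
The new return annotation pins down the exact shape of the nested dict using `Literal` keys. As a hedged sketch (with a simplified stand-in for the real `RoomKey` type in the storage layer), the annotated structure and a value matching it look like this:

    from typing import Dict

    from typing_extensions import Literal

    # Simplified stand-in for the real RoomKey type.
    RoomKey = Dict[str, str]

    BackupKeys = Dict[
        Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
    ]

    keys: BackupKeys = {
        "rooms": {
            "!room:example.org": {
                "sessions": {"session1": {"session_data": "...", "first_message_index": "0"}}
            }
        }
    }
    # A type checker would reject e.g. keys["spaces"]: only the literal key "rooms" is allowed.
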
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 32b0254c5f..1b996c420d 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -79,13 +79,14 @@ class EventStreamHandler:
                 # thundering herds on restart.
                 timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
 
-            events, tokens = await self.notifier.get_events_for(
+            stream_result = await self.notifier.get_events_for(
                 auth_user,
                 pagin_config,
                 timeout,
                 is_guest=is_guest,
                 explicit_room_id=room_id,
             )
+            events = stream_result.events
 
             time_now = self.clock.time_msec()
 
@@ -122,14 +123,12 @@ class EventStreamHandler:
                 events,
                 time_now,
                 as_client_event=as_client_event,
-                # Don't bundle aggregations as this is a deprecated API.
-                bundle_aggregations=False,
             )
 
             chunk = {
                 "chunk": chunks,
-                "start": await tokens[0].to_string(self.store),
-                "end": await tokens[1].to_string(self.store),
+                "start": await stream_result.start_token.to_string(self.store),
+                "end": await stream_result.end_token.to_string(self.store),
             }
 
             return chunk
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1ea837d082..26b8e3f43c 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -360,31 +360,34 @@ class FederationHandler:
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
         resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
-        states = await make_deferred_yieldable(
+        states_list = await make_deferred_yieldable(
             defer.gatherResults(
                 [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
             )
         )
 
-        # dict[str, dict[tuple, str]], a map from event_id to state map of
-        # event_ids.
-        states = dict(zip(event_ids, [s.state for s in states]))
+        # A map from event_id to state map of event_ids.
+        state_ids: Dict[str, StateMap[str]] = dict(
+            zip(event_ids, [s.state for s in states_list])
+        )
 
         state_map = await self.store.get_events(
-            [e_id for ids in states.values() for e_id in ids.values()],
+            [e_id for ids in state_ids.values() for e_id in ids.values()],
             get_prev_content=False,
         )
-        states = {
+
+        # A map from event_id to state map of events.
+        state_events: Dict[str, StateMap[EventBase]] = {
             key: {
                 k: state_map[e_id]
                 for k, e_id in state_dict.items()
                 if e_id in state_map
             }
-            for key, state_dict in states.items()
+            for key, state_dict in state_ids.items()
         }
 
         for e_id in event_ids:
-            likely_extremeties_domains = get_domains_from_state(states[e_id])
+            likely_extremeties_domains = get_domains_from_state(state_events[e_id])
 
             success = await try_backfill(
                 [
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9917613298..11771f3c9c 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -421,9 +421,6 @@ class FederationEventHandler:
         Raises:
             SynapseError if the response is in some way invalid.
         """
-        for e in itertools.chain(auth_events, state):
-            e.internal_metadata.outlier = True
-
         event_map = {e.event_id: e for e in itertools.chain(auth_events, state)}
 
         create_event = None
@@ -666,7 +663,9 @@ class FederationEventHandler:
         logger.info("Processing pulled event %s", event)
 
         # these should not be outliers.
-        assert not event.internal_metadata.is_outlier()
+        assert (
+            not event.internal_metadata.is_outlier()
+        ), "pulled event unexpectedly flagged as outlier"
 
         event_id = event.event_id
 
@@ -1192,7 +1191,6 @@ class FederationEventHandler:
                         [destination],
                         event_id,
                         room_version,
-                        outlier=True,
                     )
                     if event is None:
                         logger.warning(
@@ -1221,9 +1219,10 @@ class FederationEventHandler:
         """Persist a batch of outlier events fetched from remote servers.
 
         We first sort the events to make sure that we process each event's auth_events
-        before the event itself, and then auth and persist them.
+        before the event itself.
 
-        Notifies about the events where appropriate.
+        We then mark the events as outliers, persist them to the database, and, where
+        appropriate (eg, an invite), awake the notifier.
 
         Params:
             room_id: the room that the events are meant to be in (though this has
@@ -1274,7 +1273,8 @@ class FederationEventHandler:
         Persists a batch of events where we have (theoretically) already persisted all
         of their auth events.
 
-        Notifies about the events where appropriate.
+        Marks the events as outliers, auths them, persists them to the database, and,
+        where appropriate (eg, an invite), awakes the notifier.
 
         Params:
             origin: where the events came from
@@ -1312,6 +1312,9 @@ class FederationEventHandler:
                         return None
                     auth.append(ae)
 
+                # we're not bothering about room state, so flag the event as an outlier.
+                event.internal_metadata.outlier = True
+
                 context = EventContext.for_outlier()
                 try:
                     validate_event_for_room_version(room_version_obj, event)
@@ -1838,7 +1841,7 @@ class FederationEventHandler:
             The stream ID after which all events have been persisted.
         """
         if not event_and_contexts:
-            return self._store.get_current_events_token()
+            return self._store.get_room_max_stream_ordering()
 
         instance = self._config.worker.events_shard_config.get_instance(room_id)
         if instance != self._instance_name:
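
The docstring above says the batch is sorted so that each event's auth_events are processed before the event itself. A hedged, generic sketch of that kind of dependency-first ordering (a simple Kahn-style topological sort over a hypothetical `auth_event_ids` mapping, not Synapse's actual sorting helper):

    from collections import deque
    from typing import Dict, List


    def sort_auth_first(auth_event_ids: Dict[str, List[str]]) -> List[str]:
        """Order event IDs so that every event comes after its auth events.

        `auth_event_ids` maps event_id -> the IDs of its auth events that are
        also in this batch (dependencies outside the batch are ignored).
        """
        indegree = {e: 0 for e in auth_event_ids}
        dependants: Dict[str, List[str]] = {e: [] for e in auth_event_ids}
        for event_id, deps in auth_event_ids.items():
            for dep in deps:
                if dep in indegree:
                    indegree[event_id] += 1
                    dependants[dep].append(event_id)

        queue = deque(e for e, n in indegree.items() if n == 0)
        ordered = []
        while queue:
            event_id = queue.popleft()
            ordered.append(event_id)
            for child in dependants[event_id]:
                indegree[child] -= 1
                if indegree[child] == 0:
                    queue.append(child)
        return ordered


    # "B" is authed by "A", so "A" must come (and be persisted) first.
    print(sort_auth_first({"A": [], "B": ["A"]}))  # ['A', 'B']
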
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 9cd21e7f2b..601bab67f9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -13,21 +13,27 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
 from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import SynapseError
+from synapse.events import EventBase
 from synapse.events.validator import EventValidator
 from synapse.handlers.presence import format_user_presence_state
 from synapse.handlers.receipts import ReceiptEventSource
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+    JsonDict,
+    Requester,
+    RoomStreamToken,
+    StateMap,
+    StreamToken,
+    UserID,
+)
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
@@ -167,8 +173,6 @@ class InitialSyncHandler:
                 d["invite"] = await self._event_serializer.serialize_event(
                     invite_event,
                     time_now,
-                    # Don't bundle aggregations as this is a deprecated API.
-                    bundle_aggregations=False,
                     as_client_event=as_client_event,
                 )
 
@@ -190,14 +194,13 @@ class InitialSyncHandler:
                     )
                     deferred_room_state = run_in_background(
                         self.state_store.get_state_for_events, [event.event_id]
-                    )
-                    deferred_room_state.addCallback(
-                        lambda states: states[event.event_id]
+                    ).addCallback(
+                        lambda states: cast(StateMap[EventBase], states[event.event_id])
                     )
 
                 (messages, token), current_state = await make_deferred_yieldable(
-                    defer.gatherResults(
-                        [
+                    gather_results(
+                        (
                             run_in_background(
                                 self.store.get_recent_events_for_room,
                                 event.room_id,
@@ -205,7 +208,7 @@ class InitialSyncHandler:
                                 end_token=room_end_token,
                             ),
                             deferred_room_state,
-                        ]
+                        )
                     )
                 ).addErrback(unwrapFirstError)
 
@@ -222,8 +225,6 @@ class InitialSyncHandler:
                         await self._event_serializer.serialize_events(
                             messages,
                             time_now=time_now,
-                            # Don't bundle aggregations as this is a deprecated API.
-                            bundle_aggregations=False,
                             as_client_event=as_client_event,
                         )
                     ),
@@ -234,8 +235,6 @@ class InitialSyncHandler:
                 d["state"] = await self._event_serializer.serialize_events(
                     current_state.values(),
                     time_now=time_now,
-                    # Don't bundle aggregations as this is a deprecated API.
-                    bundle_aggregations=False,
                     as_client_event=as_client_event,
                 )
 
@@ -377,9 +376,7 @@ class InitialSyncHandler:
             "messages": {
                 "chunk": (
                     # Don't bundle aggregations as this is a deprecated API.
-                    await self._event_serializer.serialize_events(
-                        messages, time_now, bundle_aggregations=False
-                    )
+                    await self._event_serializer.serialize_events(messages, time_now)
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
@@ -387,7 +384,7 @@ class InitialSyncHandler:
             "state": (
                 # Don't bundle aggregations as this is a deprecated API.
                 await self._event_serializer.serialize_events(
-                    room_state.values(), time_now, bundle_aggregations=False
+                    room_state.values(), time_now
                 )
             ),
             "presence": [],
@@ -408,7 +405,7 @@ class InitialSyncHandler:
         time_now = self.clock.time_msec()
         # Don't bundle aggregations as this is a deprecated API.
         state = await self._event_serializer.serialize_events(
-            current_state.values(), time_now, bundle_aggregations=False
+            current_state.values(), time_now
         )
 
         now_token = self.hs.get_event_sources().get_current_token()
@@ -454,8 +451,8 @@ class InitialSyncHandler:
             return receipts
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
-            defer.gatherResults(
-                [
+            gather_results(
+                (
                     run_in_background(get_presence),
                     run_in_background(get_receipts),
                     run_in_background(
@@ -464,7 +461,7 @@ class InitialSyncHandler:
                         limit=limit,
                         end_token=now_token.room_key,
                     ),
-                ],
+                ),
                 consumeErrors=True,
             ).addErrback(unwrapFirstError)
         )
@@ -483,9 +480,7 @@ class InitialSyncHandler:
             "messages": {
                 "chunk": (
                     # Don't bundle aggregations as this is a deprecated API.
-                    await self._event_serializer.serialize_events(
-                        messages, time_now, bundle_aggregations=False
-                    )
+                    await self._event_serializer.serialize_events(messages, time_now)
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0b41dd38ef..d3e8303b83 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
-from twisted.internet import defer
 from twisted.internet.interfaces import IDelayedCall
 
 from synapse import event_auth
@@ -57,7 +56,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
 from synapse.util import json_decoder, json_encoder, log_failure
-from synapse.util.async_helpers import Linearizer, unwrapFirstError
+from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client
@@ -496,6 +495,7 @@ class EventCreationHandler:
         require_consent: bool = True,
         outlier: bool = False,
         historical: bool = False,
+        allow_no_prev_events: bool = False,
         depth: Optional[int] = None,
     ) -> Tuple[EventBase, EventContext]:
         """
@@ -607,6 +607,7 @@ class EventCreationHandler:
             prev_event_ids=prev_event_ids,
             auth_event_ids=auth_event_ids,
             depth=depth,
+            allow_no_prev_events=allow_no_prev_events,
         )
 
         # In an ideal world we wouldn't need the second part of this condition. However,
@@ -882,6 +883,7 @@ class EventCreationHandler:
         prev_event_ids: Optional[List[str]] = None,
         auth_event_ids: Optional[List[str]] = None,
         depth: Optional[int] = None,
+        allow_no_prev_events: bool = False,
     ) -> Tuple[EventBase, EventContext]:
         """Create a new event for a local client
 
@@ -912,6 +914,7 @@ class EventCreationHandler:
         full_state_ids_at_event = None
         if auth_event_ids is not None:
             # If auth events are provided, prev events must be also.
+            # prev_event_ids could be an empty array though.
             assert prev_event_ids is not None
 
             # Copy the full auth state before it stripped down
@@ -943,14 +946,22 @@ class EventCreationHandler:
         else:
             prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
 
-        # we now ought to have some prev_events (unless it's a create event).
-        #
-        # do a quick sanity check here, rather than waiting until we've created the
+        # Do a quick sanity check here, rather than waiting until we've created the
         # event and then try to auth it (which fails with a somewhat confusing "No
         # create event in auth events")
-        assert (
-            builder.type == EventTypes.Create or len(prev_event_ids) > 0
-        ), "Attempting to create an event with no prev_events"
+        if allow_no_prev_events:
+            # We allow events with no `prev_events`, but they must have some `auth_events`
+            assert (
+                builder.type == EventTypes.Create
+                # Allow an event to have an empty list of prev_event_ids
+                # only if it has auth_event_ids.
+                or auth_event_ids
+            ), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids"
+        else:
+            # we now ought to have some prev_events (unless it's a create event).
+            assert (
+                builder.type == EventTypes.Create or prev_event_ids
+            ), "Attempting to create a non-m.room.create event with no prev_events"
 
         event = await builder.build(
             prev_event_ids=prev_event_ids,
@@ -1156,9 +1167,9 @@ class EventCreationHandler:
 
         # We now persist the event (and update the cache in parallel, since we
         # don't want to block on it).
-        result = await make_deferred_yieldable(
-            defer.gatherResults(
-                [
+        result, _ = await make_deferred_yieldable(
+            gather_results(
+                (
                     run_in_background(
                         self._persist_event,
                         requester=requester,
@@ -1170,12 +1181,12 @@ class EventCreationHandler:
                     run_in_background(
                         self.cache_joined_hosts_for_event, event, context
                     ).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
-                ],
+                ),
                 consumeErrors=True,
             )
         ).addErrback(unwrapFirstError)
 
-        return result[0]
+        return result
 
     async def _persist_event(
         self,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 4f42438053..7469cc55a2 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -542,7 +542,10 @@ class PaginationHandler:
         chunk = {
             "chunk": (
                 await self._event_serializer.serialize_events(
-                    events, time_now, as_client_event=as_client_event
+                    events,
+                    time_now,
+                    bundle_aggregations=True,
+                    as_client_event=as_client_event,
                 )
             ),
             "start": await from_token.to_string(self.store),
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 454d06c973..c781fefb1b 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -729,7 +729,7 @@ class PresenceHandler(BasePresenceHandler):
 
         # Presence is best effort and quickly heals itself, so lets just always
         # stream from the current state when we restart.
-        self._event_pos = self.store.get_current_events_token()
+        self._event_pos = self.store.get_room_max_stream_ordering()
         self._event_processing = False
 
     async def _on_shutdown(self) -> None:
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4911a11535..5cb1ff749d 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,7 +14,7 @@
 import logging
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
 from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@@ -178,7 +178,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
             for event_id in content.keys():
                 event_content = content.get(event_id, {})
-                m_read = event_content.get("m.read", {})
+                m_read = event_content.get(ReceiptTypes.READ, {})
 
                 # If m_read is missing copy over the original event_content as there is nothing to process here
                 if not m_read:
@@ -206,7 +206,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
                 # Set new users unless empty
                 if len(new_users.keys()) > 0:
-                    new_event["content"][event_id] = {"m.read": new_users}
+                    new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
 
             # Append new_event to visible_events unless empty
             if len(new_event["content"].keys()) > 0:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ead2198e14..b9c1cbffa5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -172,7 +172,7 @@ class RoomCreationHandler:
         user_id = requester.user.to_string()
 
         # Check if this room is already being upgraded by another person
-        for key in self._upgrade_response_cache.pending_result_cache:
+        for key in self._upgrade_response_cache.keys():
             if key[0] == old_room_id and key[1] != user_id:
                 # Two different people are trying to upgrade the same room.
                 # Send the second an error.
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index ba7a14d651..1a33211a1f 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
 from typing import TYPE_CHECKING, Any, Optional, Tuple
 
+import attr
 import msgpack
 from unpaddedbase64 import decode_base64, encode_base64
 
@@ -474,16 +474,12 @@ class RoomListHandler:
         )
 
 
-class RoomListNextBatch(
-    namedtuple(
-        "RoomListNextBatch",
-        (
-            "last_joined_members",  # The count to get rooms after/before
-            "last_room_id",  # The room_id to get rooms after/before
-            "direction_is_forward",  # Bool if this is a next_batch, false if prev_batch
-        ),
-    )
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomListNextBatch:
+    last_joined_members: int  # The count to get rooms after/before
+    last_room_id: str  # The room_id to get rooms after/before
+    direction_is_forward: bool  # True if this is a next_batch, false if prev_batch
+
     KEY_DICT = {
         "last_joined_members": "m",
         "last_room_id": "r",
@@ -502,12 +498,12 @@ class RoomListNextBatch(
     def to_token(self) -> str:
         return encode_base64(
             msgpack.dumps(
-                {self.KEY_DICT[key]: val for key, val in self._asdict().items()}
+                {self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()}
             )
         )
 
     def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
-        return self._replace(**kwds)
+        return attr.evolve(self, **kwds)
 
 
 def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
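
`RoomListNextBatch` keeps its msgpack-token round trip after the attrs conversion by swapping the namedtuple's `_asdict()` for `attr.asdict()` and `_replace()` for `attr.evolve()`. A standalone sketch of the same pattern (hypothetical `NextBatch` with short wire keys and stdlib base64, not the real token format):

    from base64 import urlsafe_b64decode, urlsafe_b64encode

    import attr
    import msgpack


    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class NextBatch:
        last_joined_members: int
        last_room_id: str
        direction_is_forward: bool

        KEY_DICT = {
            "last_joined_members": "m",
            "last_room_id": "r",
            "direction_is_forward": "d",
        }
        REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}

        def to_token(self) -> str:
            # attr.asdict() replaces the namedtuple's _asdict() here.
            short = {self.KEY_DICT[k]: v for k, v in attr.asdict(self).items()}
            return urlsafe_b64encode(msgpack.dumps(short)).decode("ascii")

        @classmethod
        def from_token(cls, token: str) -> "NextBatch":
            decoded = msgpack.loads(urlsafe_b64decode(token), raw=False)
            return cls(**{cls.REVERSE_KEY_DICT[k]: v for k, v in decoded.items()})

        def copy_and_replace(self, **kwds) -> "NextBatch":
            # attr.evolve() replaces the namedtuple's _replace().
            return attr.evolve(self, **kwds)


    batch = NextBatch(last_joined_members=10, last_room_id="!r:example.org", direction_is_forward=True)
    assert NextBatch.from_token(batch.to_token()) == batch
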
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index cac76d0221..27e2903a8f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -678,7 +678,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if block_invite:
                 raise SynapseError(403, "Invites have been disabled on this server")
 
-        if prev_event_ids:
+        # An empty prev_events list is allowed as long as the auth_event_ids are present
+        if prev_event_ids is not None:
             return await self._local_membership_update(
                 requester=requester,
                 target=target,
@@ -1039,7 +1040,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Add new room to the room directory if the old room was there
         # Remove old room from the room directory
         old_room = await self.store.get_room(old_room_id)
-        if old_room and old_room["is_public"]:
+        if old_room is not None and old_room["is_public"]:
             await self.store.set_room_is_public(old_room_id, False)
             await self.store.set_room_is_public(room_id, True)
 
@@ -1050,7 +1051,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
         for group_id in local_group_ids:
             # Add new the new room to those groups
-            await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+            await self.store.add_room_to_group(
+                group_id, room_id, old_room is not None and old_room["is_public"]
+            )
 
             # Remove the old room from those groups
             await self.store.remove_room_from_group(group_id, old_room_id)
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index bd3e6f2ec7..29e41a4c79 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -80,6 +80,17 @@ class StatsHandler:
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_stats_positions()
+            room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+            if self.pos > room_max_stream_ordering:
+                # apparently, we've processed more events than exist in the database!
+                # this can happen if events are removed with history purge or similar.
+                logger.warning(
+                    "Event stream ordering appears to have gone backwards (%i -> %i): "
+                    "rewinding stats processor",
+                    self.pos,
+                    room_max_stream_ordering,
+                )
+                self.pos = room_max_stream_ordering
 
         # Loop round handling deltas until we're up to date
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f3039c3c3f..7baf3f199c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,7 +28,7 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -36,6 +36,7 @@ from synapse.events import EventBase
 from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
+from synapse.storage.databases.main.event_push_actions import NotifCounts
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -421,7 +422,7 @@ class SyncHandler:
         span to track the sync. See `generate_sync_result` for the next part of your
         indoctrination.
         """
-        with start_active_span("current_sync_for_user"):
+        with start_active_span("sync.current_sync_for_user"):
             log_kv({"since_token": since_token})
             sync_result = await self.generate_sync_result(
                 sync_config, since_token, full_state
@@ -1041,18 +1042,17 @@ class SyncHandler:
 
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> Dict[str, int]:
+    ) -> NotifCounts:
         with Measure(self.clock, "unread_notifs_for_room_id"):
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
                 room_id=room_id,
-                receipt_type="m.read",
+                receipt_type=ReceiptTypes.READ,
             )
 
-            notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+            return await self.store.get_unread_event_push_actions_by_room_for_user(
                 room_id, sync_config.user.to_string(), last_unread_event_id
             )
-            return notifs
 
     async def generate_sync_result(
         self,
@@ -1585,7 +1585,8 @@ class SyncHandler:
             )
             logger.debug("Generated room entry for %s", room_entry.room_id)
 
-        await concurrently_execute(handle_room_entries, room_entries, 10)
+        with start_active_span("sync.generate_room_entries"):
+            await concurrently_execute(handle_room_entries, room_entries, 10)
 
         sync_result_builder.invited.extend(invited)
         sync_result_builder.knocked.extend(knocked)
@@ -1662,20 +1663,20 @@ class SyncHandler:
     ) -> _RoomChanges:
         """Determine the changes in rooms to report to the user.
 
-        Ideally, we want to report all events whose stream ordering `s` lies in the
-        range `since_token < s <= now_token`, where the two tokens are read from the
-        sync_result_builder.
+        This function is a first pass at generating the rooms part of the sync response.
+        It determines which rooms have changed during the sync period, and categorises
+        them into four buckets: "knock", "invite", "join" and "leave".
 
-        If there are too many events in that range to report, things get complicated.
-        In this situation we return a truncated list of the most recent events, and
-        indicate in the response that there is a "gap" of omitted events. Additionally:
+        1. Finds all membership changes for the user in the sync period (from
+           `since_token` up to `now_token`).
+        2. Uses those to place the room in one of the four categories above.
+        3. Builds a `_RoomChanges` struct to record this, and returns that struct.
 
-        - we include a "state_delta", to describe the changes in state over the gap,
-        - we include all membership events applying to the user making the request,
-          even those in the gap.
-
-        See the spec for the rationale:
-            https://spec.matrix.org/v1.1/client-server-api/#syncing
+        For rooms classified as "knock", "invite" or "leave", we just need to report
+        a single membership event in the eventual /sync response. For "join" we need
+        to fetch additional non-membership events, e.g. messages in the room. That is
+        more complicated, so instead we report an intermediary `RoomSyncResultBuilder`
+        struct, and leave the additional work to `_generate_room_entry`.
 
         The sync_result_builder is not modified by this function.
         """
@@ -1686,16 +1687,6 @@ class SyncHandler:
 
         assert since_token
 
-        # The spec
-        #     https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
-        # notes that membership events need special consideration:
-        #
-        # > When a sync is limited, the server MUST return membership events for events
-        # > in the gap (between since and the start of the returned timeline), regardless
-        # > as to whether or not they are redundant.
-        #
-        # We fetch such events here, but we only seem to use them for categorising rooms
-        # as newly joined, newly left, invited or knocked.
         # TODO: we've already called this function and ran this query in
         #       _have_rooms_changed. We could keep the results in memory to avoid a
         #       second query, at the cost of more complicated source code.
@@ -2009,6 +2000,23 @@ class SyncHandler:
         """Populates the `joined` and `archived` section of `sync_result_builder`
         based on the `room_builder`.
 
+        Ideally, we want to report all events whose stream ordering `s` lies in the
+        range `since_token < s <= now_token`, where the two tokens are read from the
+        sync_result_builder.
+
+        If there are too many events in that range to report, things get complicated.
+        In this situation we return a truncated list of the most recent events, and
+        indicate in the response that there is a "gap" of omitted events. Lots of this
+        is handled in `_load_filtered_recents`, but some of it is handled in this method.
+
+        Additionally:
+        - we include a "state_delta", to describe the changes in state over the gap,
+        - we include all membership events applying to the user making the request,
+          even those in the gap.
+
+        See the spec for the rationale:
+            https://spec.matrix.org/v1.1/client-server-api/#syncing
+
         Args:
             sync_result_builder
             ignored_users: Set of users ignored by user.
@@ -2038,7 +2046,7 @@ class SyncHandler:
         since_token = room_builder.since_token
         upto_token = room_builder.upto_token
 
-        with start_active_span("generate_room_entry"):
+        with start_active_span("sync.generate_room_entry"):
             set_tag("room_id", room_id)
             log_kv({"events": len(events or ())})
 
@@ -2166,10 +2174,10 @@ class SyncHandler:
                 if room_sync or always_include:
                     notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-                    unread_notifications["notification_count"] = notifs["notify_count"]
-                    unread_notifications["highlight_count"] = notifs["highlight_count"]
+                    unread_notifications["notification_count"] = notifs.notify_count
+                    unread_notifications["highlight_count"] = notifs.highlight_count
 
-                    room_sync.unread_count = notifs["unread_count"]
+                    room_sync.unread_count = notifs.unread_count
 
                     sync_result_builder.joined.append(room_sync)
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 1676ebd057..e43c22832d 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 import logging
 import random
-from collections import namedtuple
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
+import attr
+
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import (
@@ -37,7 +38,10 @@ logger = logging.getLogger(__name__)
 
 # A tiny object useful for storing a user's membership in a room, as a mapping
 # key
-RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomMember:
+    room_id: str
+    user_id: str
 
 
 # How often we expect remote servers to resend us presence.
@@ -119,7 +123,7 @@ class FollowerTypingHandler:
         self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
 
     def is_typing(self, member: RoomMember) -> bool:
-        return member.user_id in self._room_typing.get(member.room_id, [])
+        return member.user_id in self._room_typing.get(member.room_id, set())
 
     async def _push_remote(self, member: RoomMember, typing: bool) -> None:
         if not self.federation:
@@ -166,9 +170,9 @@ class FollowerTypingHandler:
         for row in rows:
             self._room_serials[row.room_id] = token
 
-            prev_typing = set(self._room_typing.get(row.room_id, []))
+            prev_typing = self._room_typing.get(row.room_id, set())
             now_typing = set(row.user_ids)
-            self._room_typing[row.room_id] = row.user_ids
+            self._room_typing[row.room_id] = now_typing
 
             if self.federation:
                 run_as_background_process(
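
The comment above notes that `RoomMember` is used as a mapping key; that keeps working after the conversion because frozen `attrs` classes define `__eq__` and `__hash__` over their field values, just as the old namedtuple did. A tiny sketch (hedged, with a hypothetical last-active dict standing in for the real wheel timer):

    from typing import Dict

    import attr


    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class RoomMember:
        room_id: str
        user_id: str


    # Frozen attrs instances hash by value, so equal members collapse to one key.
    last_active: Dict[RoomMember, int] = {}
    last_active[RoomMember("!room:example.org", "@alice:example.org")] = 1000
    last_active[RoomMember("!room:example.org", "@alice:example.org")] = 2000
    print(len(last_active), last_active[RoomMember("!room:example.org", "@alice:example.org")])
    # -> 1 2000
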
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index a0eb45446f..1565e034cb 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -148,9 +148,21 @@ class UserDirectoryHandler(StateDeltasHandler):
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()
 
-        # If still None then the initial background update hasn't happened yet.
-        if self.pos is None:
-            return None
+            # If still None then the initial background update hasn't happened yet.
+            if self.pos is None:
+                return None
+
+            room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+            if self.pos > room_max_stream_ordering:
+                # apparently, we've processed more events than exist in the database!
+                # this can happen if events are removed with history purge or similar.
+                logger.warning(
+                    "Event stream ordering appears to have gone backwards (%i -> %i): "
+                    "rewinding user directory processor",
+                    self.pos,
+                    room_max_stream_ordering,
+                )
+                self.pos = room_max_stream_ordering
 
         # Loop round handling deltas until we're up to date
         while True:
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 578fc48ef4..efecb089c1 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError
 class RequestTimedOutError(SynapseError):
     """Exception representing timeout of an outbound request"""
 
-    def __init__(self, msg):
+    def __init__(self, msg: str):
         super().__init__(504, msg)
 
 
@@ -33,7 +33,7 @@ ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
 CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
 
 
-def redact_uri(uri):
+def redact_uri(uri: str) -> str:
     """Strips sensitive information from the uri replaces with <redacted>"""
     uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
     return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
@@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
     https://twistedmatrix.com/trac/ticket/6528
     """
 
-    def stopProducing(self):
+    def stopProducing(self) -> None:
         try:
             FileBodyProducer.stopProducing(self)
         except task.TaskStopped:
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 9a2684aca4..6a9f6635d2 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
 
 from twisted.web.server import Request
 
@@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource):
     and exception handling.
     """
 
-    def __init__(self, hs: "HomeServer", handler):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
+    ):
         """Initialise AdditionalResource
 
         The ``handler`` should return a deferred which completes when it has
@@ -47,7 +51,7 @@ class AdditionalResource(DirectServeJsonResource):
         super().__init__()
         self._handler = handler
 
-    def _async_render(self, request: Request):
+    async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
         # Cheekily pass the result straight through, so we don't need to worry
         # if its an awaitable or not.
-        return self._handler(request)
+        return await self._handler(request)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index b5a2d333a6..ca33b45cb2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
+from http import HTTPStatus
 from io import BytesIO
 from typing import (
     TYPE_CHECKING,
@@ -280,7 +281,9 @@ class BlacklistingAgentWrapper(Agent):
                 ip_address, self._ip_whitelist, self._ip_blacklist
             ):
                 logger.info("Blocking access to %s due to blacklist" % (ip_address,))
-                e = SynapseError(403, "IP address blocked by IP blacklist entry")
+                e = SynapseError(
+                    HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
+                )
                 return defer.fail(Failure(e))
 
         return self._agent.request(
@@ -585,7 +588,7 @@ class SimpleHttpClient:
         if headers:
             actual_headers.update(headers)  # type: ignore
 
-        body = await self.get_raw(uri, args, headers=headers)
+        body = await self.get_raw(uri, args, headers=actual_headers)
         return json_decoder.decode(body.decode("utf-8"))
 
     async def put_json(
@@ -719,7 +722,9 @@ class SimpleHttpClient:
 
         if response.code > 299:
             logger.warning("Got %d when downloading %s" % (response.code, url))
-            raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
+            raise SynapseError(
+                HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
+            )
 
         # TODO: if our Content-Type is HTML or something, just read the first
         # N bytes into RAM rather than saving it all to disk only to read it
@@ -731,12 +736,14 @@ class SimpleHttpClient:
             )
         except BodyExceededMaxSize:
             raise SynapseError(
-                502,
+                HTTPStatus.BAD_GATEWAY,
                 "Requested file is too large > %r bytes" % (max_size,),
                 Codes.TOO_LARGE,
             )
         except Exception as e:
-            raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
+            raise SynapseError(
+                HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
+            ) from e
 
         return (
             length,
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 1238bfd287..a8a520f809 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -25,6 +25,7 @@ from zope.interface import implementer
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet.interfaces import (
+    IProtocol,
     IProtocolFactory,
     IReactorCore,
     IStreamClientEndpoint,
@@ -309,12 +310,14 @@ class MatrixHostnameEndpoint:
 
         self._srv_resolver = srv_resolver
 
-    def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
+    def connect(
+        self, protocol_factory: IProtocolFactory
+    ) -> "defer.Deferred[IProtocol]":
         """Implements IStreamClientEndpoint interface"""
 
         return run_in_background(self._do_connect, protocol_factory)
 
-    async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
+    async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
         first_exception = None
 
         server_list = await self._resolve_server()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 203d723d41..deedde0b5b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,6 +19,7 @@ import random
 import sys
 import typing
 import urllib.parse
+from http import HTTPStatus
 from io import BytesIO, StringIO
 from typing import (
     TYPE_CHECKING,
@@ -1154,7 +1155,7 @@ class MatrixFederationHttpClient:
                 request.destination,
                 msg,
             )
-            raise SynapseError(502, msg, Codes.TOO_LARGE)
+            raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
         except defer.TimeoutError as e:
             logger.warning(
                 "{%s} [%s] Timed out reading response - %s %s",
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 91badb0b0a..09b4125489 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import abc
-import collections
 import html
 import logging
 import types
@@ -30,12 +29,14 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    NoReturn,
     Optional,
     Pattern,
     Tuple,
     Union,
 )
 
+import attr
 import jinja2
 from canonicaljson import encode_canonical_json
 from typing_extensions import Protocol
@@ -57,12 +58,14 @@ from synapse.api.errors import (
 )
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
-from synapse.logging.opentracing import trace_servlet
+from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
 from synapse.util import json_encoder
 from synapse.util.caches import intern_dict
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
+    import opentracing
+
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
@@ -170,7 +173,9 @@ def return_html_error(
     respond_with_html(request, code, body)
 
 
-def wrap_async_request_handler(h):
+def wrap_async_request_handler(
+    h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
+) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
     """Wraps an async request handler so that it calls request.processing.
 
     This helps ensure that work done by the request handler after the request is completed
@@ -183,7 +188,9 @@ def wrap_async_request_handler(h):
     logged until the deferred completes.
     """
 
-    async def wrapped_async_request_handler(self, request):
+    async def wrapped_async_request_handler(
+        self: "_AsyncResource", request: SynapseRequest
+    ) -> None:
         with request.processing():
             await h(self, request)
 
@@ -240,18 +247,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             context from the request the servlet is handling.
     """
 
-    def __init__(self, extract_context=False):
+    def __init__(self, extract_context: bool = False):
         super().__init__()
 
         self._extract_context = extract_context
 
-    def render(self, request):
+    def render(self, request: SynapseRequest) -> int:
         """This gets called by twisted every time someone sends us a request."""
         defer.ensureDeferred(self._async_render_wrapper(request))
         return NOT_DONE_YET
 
     @wrap_async_request_handler
-    async def _async_render_wrapper(self, request: SynapseRequest):
+    async def _async_render_wrapper(self, request: SynapseRequest) -> None:
         """This is a wrapper that delegates to `_async_render` and handles
         exceptions, return values, metrics, etc.
         """
@@ -271,7 +278,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             f = failure.Failure()
             self._send_error_response(f, request)
 
-    async def _async_render(self, request: Request):
+    async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
         """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
         no appropriate method exists. Can be overridden in sub classes for
         different routing.
@@ -318,7 +325,7 @@ class DirectServeJsonResource(_AsyncResource):
     formatting responses and errors as JSON.
     """
 
-    def __init__(self, canonical_json=False, extract_context=False):
+    def __init__(self, canonical_json: bool = False, extract_context: bool = False):
         super().__init__(extract_context)
         self.canonical_json = canonical_json
 
@@ -327,7 +334,7 @@ class DirectServeJsonResource(_AsyncResource):
         request: SynapseRequest,
         code: int,
         response_object: Any,
-    ):
+    ) -> None:
         """Implements _AsyncResource._send_response"""
         # TODO: Only enable CORS for the requests that need it.
         respond_with_json(
@@ -347,9 +354,11 @@ class DirectServeJsonResource(_AsyncResource):
         return_json_error(f, request)
 
 
-_PathEntry = collections.namedtuple(
-    "_PathEntry", ["pattern", "callback", "servlet_classname"]
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _PathEntry:
+    pattern: Pattern
+    callback: ServletCallback
+    servlet_classname: str
 
 
 class JsonResource(DirectServeJsonResource):
@@ -368,34 +377,45 @@ class JsonResource(DirectServeJsonResource):
 
     isLeaf = True
 
-    def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        canonical_json: bool = True,
+        extract_context: bool = False,
+    ):
         super().__init__(canonical_json, extract_context)
         self.clock = hs.get_clock()
         self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
         self.hs = hs
 
-    def register_paths(self, method, path_patterns, callback, servlet_classname):
+    def register_paths(
+        self,
+        method: str,
+        path_patterns: Iterable[Pattern],
+        callback: ServletCallback,
+        servlet_classname: str,
+    ) -> None:
         """
         Registers a request handler against a regular expression. Later request URLs are
         checked against these regular expressions in order to identify an appropriate
         handler for that request.
 
         Args:
-            method (str): GET, POST etc
+            method: GET, POST etc
 
-            path_patterns (Iterable[str]): A list of regular expressions to which
-                the request URLs are compared.
+            path_patterns: A list of regular expressions to which the request
+                URLs are compared.
 
-            callback (function): The handler for the request. Usually a Servlet
+            callback: The handler for the request. Usually a Servlet
 
-            servlet_classname (str): The name of the handler to be used in prometheus
+            servlet_classname: The name of the handler to be used in prometheus
                 and opentracing logs.
         """
-        method = method.encode("utf-8")  # method is bytes on py3
+        method_bytes = method.encode("utf-8")
 
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
-            self.path_regexs.setdefault(method, []).append(
+            self.path_regexs.setdefault(method_bytes, []).append(
                 _PathEntry(path_pattern, callback, servlet_classname)
             )
 
@@ -427,7 +447,7 @@ class JsonResource(DirectServeJsonResource):
         # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
         return _unrecognised_request_handler, "unrecognised_request_handler", {}
 
-    async def _async_render(self, request):
+    async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
         # Make sure we have an appropriate name for this handler in prometheus
@@ -468,7 +488,7 @@ class DirectServeHtmlResource(_AsyncResource):
         request: SynapseRequest,
         code: int,
         response_object: Any,
-    ):
+    ) -> None:
         """Implements _AsyncResource._send_response"""
         # We expect to get bytes for us to write
         assert isinstance(response_object, bytes)
@@ -492,12 +512,12 @@ class StaticResource(File):
     Differs from the File resource by adding clickjacking protection.
     """
 
-    def render_GET(self, request: Request):
+    def render_GET(self, request: Request) -> bytes:
         set_clickjacking_protection_headers(request)
         return super().render_GET(request)
 
 
-def _unrecognised_request_handler(request):
+def _unrecognised_request_handler(request: Request) -> NoReturn:
     """Request handler for unrecognised requests
 
     This is a request handler suitable for return from
@@ -505,7 +525,7 @@ def _unrecognised_request_handler(request):
     UnrecognizedRequestError.
 
     Args:
-        request (twisted.web.http.Request):
+        request: Unused, but passed in to match the signature of ServletCallback.
     """
     raise UnrecognizedRequestError()
 
@@ -513,23 +533,23 @@ def _unrecognised_request_handler(request):
 class RootRedirect(resource.Resource):
     """Redirects the root '/' path to another path."""
 
-    def __init__(self, path):
-        resource.Resource.__init__(self)
+    def __init__(self, path: str):
+        super().__init__()
         self.url = path
 
-    def render_GET(self, request):
+    def render_GET(self, request: Request) -> bytes:
         return redirectTo(self.url.encode("ascii"), request)
 
-    def getChild(self, name, request):
+    def getChild(self, name: str, request: Request) -> resource.Resource:
         if len(name) == 0:
             return self  # select ourselves as the child to render
-        return resource.Resource.getChild(self, name, request)
+        return super().getChild(name, request)
 
 
 class OptionsResource(resource.Resource):
     """Responds to OPTION requests for itself and all children."""
 
-    def render_OPTIONS(self, request):
+    def render_OPTIONS(self, request: Request) -> bytes:
         request.setResponseCode(204)
         request.setHeader(b"Content-Length", b"0")
 
@@ -537,10 +557,10 @@ class OptionsResource(resource.Resource):
 
         return b""
 
-    def getChildWithDefault(self, path, request):
+    def getChildWithDefault(self, path: str, request: Request) -> resource.Resource:
         if request.method == b"OPTIONS":
             return self  # select ourselves as the child to render
-        return resource.Resource.getChildWithDefault(self, path, request)
+        return super().getChildWithDefault(path, request)
 
 
 class RootOptionsRedirectResource(OptionsResource, RootRedirect):
@@ -649,7 +669,7 @@ def respond_with_json(
     json_object: Any,
     send_cors: bool = False,
     canonical_json: bool = True,
-):
+) -> Optional[int]:
     """Sends encoded JSON in response to the given request.
 
     Args:
@@ -696,7 +716,7 @@ def respond_with_json_bytes(
     code: int,
     json_bytes: bytes,
     send_cors: bool = False,
-):
+) -> Optional[int]:
     """Sends encoded JSON in response to the given request.
 
     Args:
@@ -713,7 +733,7 @@ def respond_with_json_bytes(
         logger.warning(
             "Not sending response to request %s, already disconnected.", request
         )
-        return
+        return None
 
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"application/json")
@@ -731,7 +751,7 @@ async def _async_write_json_to_request_in_thread(
     request: SynapseRequest,
     json_encoder: Callable[[Any], bytes],
     json_object: Any,
-):
+) -> None:
     """Encodes the given JSON object on a thread and then writes it to the
     request.
 
@@ -743,7 +763,20 @@ async def _async_write_json_to_request_in_thread(
     expensive.
     """
 
-    json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
+    def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes:
+        # it might take a while for the threadpool to schedule us, so we write
+        # opentracing logs once we actually get scheduled, so that we can see how
+        # much that contributed.
+        if opentracing_span:
+            opentracing_span.log_kv({"event": "scheduled"})
+        res = json_encoder(json_object)
+        if opentracing_span:
+            opentracing_span.log_kv({"event": "encoded"})
+        return res
+
+    with start_active_span("encode_json_response"):
+        span = active_span()
+        json_str = await defer_to_thread(request.reactor, encode, span)
 
     _write_bytes_to_request(request, json_str)
 
@@ -773,7 +806,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
     _ByteProducer(request, bytes_generator)
 
 
-def set_cors_headers(request: Request):
+def set_cors_headers(request: Request) -> None:
     """Set the CORS headers so that javascript running in a web browsers can
     use this API
 
@@ -790,14 +823,14 @@ def set_cors_headers(request: Request):
     )
 
 
-def respond_with_html(request: Request, code: int, html: str):
+def respond_with_html(request: Request, code: int, html: str) -> None:
     """
     Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
     """
     respond_with_html_bytes(request, code, html.encode("utf-8"))
 
 
-def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
+def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None:
     """
     Sends HTML (encoded as UTF-8 bytes) as the response to the given request.
 
@@ -815,7 +848,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
         logger.warning(
             "Not sending response to request %s, already disconnected.", request
         )
-        return
+        return None
 
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
@@ -828,7 +861,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
     finish_request(request)
 
 
-def set_clickjacking_protection_headers(request: Request):
+def set_clickjacking_protection_headers(request: Request) -> None:
     """
     Set headers to guard against clickjacking of embedded content.
 
@@ -850,7 +883,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
     finish_request(request)
 
 
-def finish_request(request: Request):
+def finish_request(request: Request) -> None:
     """Finish writing the response to the request.
 
     Twisted throws a RuntimeException if the connection closed before the
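
The rewritten _async_write_json_to_request_in_thread captures the active opentracing span on the reactor thread and hands it to the worker, so the "scheduled" and "encoded" events record how much the threadpool hop contributed. A minimal sketch of the same pattern; expensive_encode is a hypothetical CPU-bound callable, not part of this patch:

    from synapse.logging.context import defer_to_thread
    from synapse.logging.opentracing import active_span, start_active_span

    async def encode_in_thread(reactor, payload) -> bytes:
        def work(span):
            # The worker thread has no active scope of its own, so the span is
            # passed in explicitly and only used for log_kv annotations.
            if span:
                span.log_kv({"event": "scheduled"})
            result = expensive_encode(payload)  # hypothetical CPU-bound call
            if span:
                span.log_kv({"event": "encoded"})
            return result

        with start_active_span("encode_in_thread"):
            return await defer_to_thread(reactor, work, active_span())
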
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 6dd9b9ad03..4ff840ca0e 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,6 +14,7 @@
 
 """ This module contains base REST classes for constructing REST servlets. """
 import logging
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Iterable,
@@ -30,6 +31,7 @@ from typing_extensions import Literal
 from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
 from synapse.types import JsonDict, RoomAlias, RoomID
 from synapse.util import json_decoder
 
@@ -137,11 +139,15 @@ def parse_integer_from_args(
             return int(args[name_bytes][0])
         except Exception:
             message = "Query parameter %r must be an integer" % (name,)
-            raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
     else:
         if required:
             message = "Missing integer query parameter %r" % (name,)
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+            )
         else:
             return default
 
@@ -246,11 +252,15 @@ def parse_boolean_from_args(
             message = (
                 "Boolean query parameter %r must be one of ['true', 'false']"
             ) % (name,)
-            raise SynapseError(400, message)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
+            )
     else:
         if required:
             message = "Missing boolean query parameter %r" % (name,)
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+            )
         else:
             return default
 
@@ -313,7 +323,7 @@ def parse_bytes_from_args(
         return args[name_bytes][0]
     elif required:
         message = "Missing string query parameter %s" % (name,)
-        raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
 
     return default
 
@@ -407,14 +417,16 @@ def _parse_string_value(
     try:
         value_str = value.decode(encoding)
     except ValueError:
-        raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding)
+        )
 
     if allowed_values is not None and value_str not in allowed_values:
         message = "Query parameter %r must be one of [%s]" % (
             name,
             ", ".join(repr(v) for v in allowed_values),
         )
-        raise SynapseError(400, message)
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
     else:
         return value_str
 
@@ -510,7 +522,9 @@ def parse_strings_from_args(
     else:
         if required:
             message = "Missing string query parameter %r" % (name,)
-            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
+            )
 
         return default
 
@@ -638,7 +652,7 @@ def parse_json_value_from_request(
     try:
         content_bytes = request.content.read()  # type: ignore
     except Exception:
-        raise SynapseError(400, "Error reading JSON content.")
+        raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.")
 
     if not content_bytes and allow_empty_body:
         return None
@@ -647,7 +661,9 @@ def parse_json_value_from_request(
         content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
         logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON
+        )
 
     return content
 
@@ -673,7 +689,7 @@ def parse_json_object_from_request(
 
     if not isinstance(content, dict):
         message = "Content must be a JSON object."
-        raise SynapseError(400, message, errcode=Codes.BAD_JSON)
+        raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON)
 
     return content
 
@@ -685,7 +701,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
             absent.append(k)
 
     if len(absent) > 0:
-        raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM
+        )
 
 
 class RestServlet:
@@ -709,7 +727,7 @@ class RestServlet:
     into the appropriate HTTP response.
     """
 
-    def register(self, http_server):
+    def register(self, http_server: HttpServer) -> None:
         """Register this servlet with the given HTTP server."""
         patterns = getattr(self, "PATTERNS", None)
         if patterns:
@@ -758,10 +776,12 @@ class ResolveRoomIdMixin:
             resolved_room_id = room_id.to_string()
         else:
             raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
+                HTTPStatus.BAD_REQUEST,
+                "%s was not legal room ID or room alias" % (room_identifier,),
             )
         if not resolved_room_id:
             raise SynapseError(
-                400, "Unknown room ID or room alias %s" % room_identifier
+                HTTPStatus.BAD_REQUEST,
+                "Unknown room ID or room alias %s" % room_identifier,
             )
         return resolved_room_id, remote_room_hosts
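
With these changes the parameter helpers consistently raise SynapseError carrying HTTPStatus.BAD_REQUEST plus an explicit Matrix errcode. A rough sketch of what a caller observes for a malformed query string (the args mapping is fabricated for illustration):

    from http import HTTPStatus

    from synapse.api.errors import Codes, SynapseError
    from synapse.http.servlet import parse_integer_from_args

    args = {b"limit": [b"not-a-number"]}  # e.g. "?limit=not-a-number"
    try:
        parse_integer_from_args(args, "limit", required=True)
    except SynapseError as e:
        assert e.code == HTTPStatus.BAD_REQUEST   # still serialises as 400
        assert e.errcode == Codes.INVALID_PARAM   # "M_INVALID_PARAM" in the JSON body
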
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 755ad56637..80f7a2ff58 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Generator, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
 
 import attr
 from zope.interface import implementer
@@ -35,6 +35,9 @@ from synapse.logging.context import (
 )
 from synapse.types import Requester
 
+if TYPE_CHECKING:
+    import opentracing
+
 logger = logging.getLogger(__name__)
 
 _next_request_seq = 0
@@ -66,9 +69,9 @@ class SynapseRequest(Request):
         self,
         channel: HTTPChannel,
         site: "SynapseSite",
-        *args,
+        *args: Any,
         max_request_body_size: int = 1024,
-        **kw,
+        **kw: Any,
     ):
         super().__init__(channel, *args, **kw)
         self._max_request_body_size = max_request_body_size
@@ -81,6 +84,10 @@ class SynapseRequest(Request):
         # server name, for client requests this is the Requester object.
         self._requester: Optional[Union[Requester, str]] = None
 
+        # An opentracing span for this request. Will be closed when the request is
+        # completely processed.
+        self._opentracing_span: "Optional[opentracing.Span]" = None
+
         # we can't yet create the logcontext, as we don't know the method.
         self.logcontext: Optional[LoggingContext] = None
 
@@ -148,6 +155,13 @@ class SynapseRequest(Request):
         # If there's no authenticated entity, it was the requester.
         self.logcontext.request.authenticated_entity = authenticated_entity or requester
 
+    def set_opentracing_span(self, span: "opentracing.Span") -> None:
+        """attach an opentracing span to this request
+
+        Doing so will cause the span to be closed when we finish processing the request
+        """
+        self._opentracing_span = span
+
     def get_request_id(self) -> str:
         return "%s-%i" % (self.get_method(), self.request_seq)
 
@@ -286,6 +300,9 @@ class SynapseRequest(Request):
             self._processing_finished_time = time.time()
             self._is_processing = False
 
+            if self._opentracing_span:
+                self._opentracing_span.log_kv({"event": "finished processing"})
+
             # if we've already sent the response, log it now; otherwise, we wait for the
             # response to be sent.
             if self.finish_time is not None:
@@ -299,6 +316,8 @@ class SynapseRequest(Request):
         """
         self.finish_time = time.time()
         Request.finish(self)
+        if self._opentracing_span:
+            self._opentracing_span.log_kv({"event": "response sent"})
         if not self._is_processing:
             assert self.logcontext is not None
             with PreserveLoggingContext(self.logcontext):
@@ -333,6 +352,11 @@ class SynapseRequest(Request):
         with PreserveLoggingContext(self.logcontext):
             logger.info("Connection from client lost before response was sent")
 
+            if self._opentracing_span:
+                self._opentracing_span.log_kv(
+                    {"event": "client connection lost", "reason": str(reason.value)}
+                )
+
             if not self._is_processing:
                 self._finished_processing()
 
@@ -421,6 +445,10 @@ class SynapseRequest(Request):
             usage.evt_db_fetch_count,
         )
 
+        # complete the opentracing span, if any.
+        if self._opentracing_span:
+            self._opentracing_span.finish()
+
         try:
             self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
         except Exception as e:
@@ -557,7 +585,7 @@ class SynapseSite(Site):
         proxied = config.http_options.x_forwarded
         request_class = XForwardedForRequest if proxied else SynapseRequest
 
-        def request_factory(channel, queued: bool) -> Request:
+        def request_factory(channel: HTTPChannel, queued: bool) -> Request:
             return request_class(
                 channel,
                 self,
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index d8ae3188b7..d4ee893376 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -22,20 +22,33 @@ them.
 
 See doc/log_contexts.rst for details on how this works.
 """
-import inspect
 import logging
 import threading
 import typing
 import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import attr
 from typing_extensions import Literal
 
 from twisted.internet import defer, threads
+from twisted.python.threadpool import ThreadPool
 
 if TYPE_CHECKING:
     from synapse.logging.scopecontextmanager import _LogContextScope
+    from synapse.types import ISynapseReactor
 
 logger = logging.getLogger(__name__)
 
@@ -55,7 +68,6 @@ try:
     def get_thread_resource_usage() -> "Optional[resource.struct_rusage]":
         return resource.getrusage(RUSAGE_THREAD)
 
-
 except Exception:
     # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
     # won't track resource usage.
@@ -66,7 +78,7 @@ except Exception:
 
 
 # a hook which can be set during testing to assert that we aren't abusing logcontexts.
-def logcontext_error(msg: str):
+def logcontext_error(msg: str) -> None:
     logger.warning(msg)
 
 
@@ -223,22 +235,19 @@ class _Sentinel:
     def __str__(self) -> str:
         return "sentinel"
 
-    def copy_to(self, record):
-        pass
-
-    def start(self, rusage: "Optional[resource.struct_rusage]"):
+    def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
         pass
 
-    def stop(self, rusage: "Optional[resource.struct_rusage]"):
+    def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
         pass
 
-    def add_database_transaction(self, duration_sec):
+    def add_database_transaction(self, duration_sec: float) -> None:
         pass
 
-    def add_database_scheduled(self, sched_sec):
+    def add_database_scheduled(self, sched_sec: float) -> None:
         pass
 
-    def record_event_fetch(self, event_count):
+    def record_event_fetch(self, event_count: int) -> None:
         pass
 
     def __bool__(self) -> Literal[False]:
@@ -379,7 +388,12 @@ class LoggingContext:
             )
         return self
 
-    def __exit__(self, type, value, traceback) -> None:
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         """Restore the logging context in thread local storage to the state it
         was before this context was entered.
         Returns:
@@ -399,17 +413,6 @@ class LoggingContext:
         # recorded against the correct metrics.
         self.finished = True
 
-    def copy_to(self, record) -> None:
-        """Copy logging fields from this context to a log record or
-        another LoggingContext
-        """
-
-        # we track the current request
-        record.request = self.request
-
-        # we also track the current scope:
-        record.scope = self.scope
-
     def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
         """
         Record that this logcontext is currently running.
@@ -626,7 +629,12 @@ class PreserveLoggingContext:
     def __enter__(self) -> None:
         self._old_context = set_current_context(self._new_context)
 
-    def __exit__(self, type, value, traceback) -> None:
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         context = set_current_context(self._old_context)
 
         if context != self._new_context:
@@ -711,16 +719,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
     )
 
 
-def preserve_fn(f):
+R = TypeVar("R")
+
+
+@overload
+def preserve_fn(  # type: ignore[misc]
+    f: Callable[..., Awaitable[R]],
+) -> Callable[..., "defer.Deferred[R]"]:
+    # The `type: ignore[misc]` above suppresses
+    # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
+    ...
+
+
+@overload
+def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
+    ...
+
+
+def preserve_fn(
+    f: Union[
+        Callable[..., R],
+        Callable[..., Awaitable[R]],
+    ]
+) -> Callable[..., "defer.Deferred[R]"]:
     """Function decorator which wraps the function with run_in_background"""
 
-    def g(*args, **kwargs):
+    def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
         return run_in_background(f, *args, **kwargs)
 
     return g
 
 
-def run_in_background(f, *args, **kwargs) -> defer.Deferred:
+@overload
+def run_in_background(  # type: ignore[misc]
+    f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
+    # The `type: ignore[misc]` above suppresses
+    # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
+    ...
+
+
+@overload
+def run_in_background(
+    f: Callable[..., R], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
+    ...
+
+
+def run_in_background(
+    f: Union[
+        Callable[..., R],
+        Callable[..., Awaitable[R]],
+    ],
+    *args: Any,
+    **kwargs: Any,
+) -> "defer.Deferred[R]":
     """Calls a function, ensuring that the current context is restored after
     return from the function, and that the sentinel context is set once the
     deferred returned by the function completes.
@@ -751,6 +804,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
     # At this point we should have a Deferred, if not then f was a synchronous
     # function, wrap it in a Deferred for consistency.
     if not isinstance(res, defer.Deferred):
+        # `res` is not a `Deferred` and not a `Coroutine`.
+        # There are no other types of `Awaitable`s we expect to encounter in Synapse.
+        assert not isinstance(res, Awaitable)
+
         return defer.succeed(res)
 
     if res.called and not res.paused:
@@ -778,13 +835,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
     return res
 
 
-def make_deferred_yieldable(deferred):
-    """Given a deferred (or coroutine), make it follow the Synapse logcontext
-    rules:
+T = TypeVar("T")
 
-    If the deferred has completed (or is not actually a Deferred), essentially
-    does nothing (just returns another completed deferred with the
-    result/failure).
+
+def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
+    """Given a deferred, make it follow the Synapse logcontext rules:
+
+    If the deferred has completed, essentially does nothing (just returns another
+    completed deferred with the result/failure).
 
     If the deferred has not yet completed, resets the logcontext before
     returning a deferred. Then, when the deferred completes, restores the
@@ -792,16 +850,6 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
-    if inspect.isawaitable(deferred):
-        # If we're given a coroutine we convert it to a deferred so that we
-        # run it and find out if it immediately finishes, it it does then we
-        # don't need to fiddle with log contexts at all and can return
-        # immediately.
-        deferred = defer.ensureDeferred(deferred)
-
-    if not isinstance(deferred, defer.Deferred):
-        return deferred
-
     if deferred.called and not deferred.paused:
         # it looks like this deferred is ready to run any callbacks we give it
         # immediately. We may as well optimise out the logcontext faffery.
@@ -823,7 +871,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
     return result
 
 
-def defer_to_thread(reactor, f, *args, **kwargs):
+def defer_to_thread(
+    reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
     """
     Calls the function `f` using a thread from the reactor's default threadpool and
     returns the result as a Deferred.
@@ -855,7 +905,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
     return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
 
 
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
+def defer_to_threadpool(
+    reactor: "ISynapseReactor",
+    threadpool: ThreadPool,
+    f: Callable[..., R],
+    *args: Any,
+    **kwargs: Any,
+) -> "defer.Deferred[R]":
     """
     A wrapper for twisted.internet.threads.deferToThreadpool, which handles
     logcontexts correctly.
@@ -897,7 +953,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         assert isinstance(curr_context, LoggingContext)
         parent_context = curr_context
 
-    def g():
+    def g() -> R:
         with LoggingContext(str(curr_context), parent_context=parent_context):
             return f(*args, **kwargs)
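
With the new overloads, run_in_background and make_deferred_yieldable are generic in the result type, so mypy can follow a value through the logcontext helpers. A minimal sketch, using an illustrative coroutine that is not part of the patch:

    from synapse.logging.context import make_deferred_yieldable, run_in_background

    async def count_widgets() -> int:
        return 3

    async def caller() -> None:
        d = run_in_background(count_widgets)      # inferred as "Deferred[int]"
        total = await make_deferred_yieldable(d)  # inferred as "int"
        assert total == 3
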
 
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 20d23a4260..622445e9f4 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -173,6 +173,7 @@ from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Typ
 import attr
 
 from twisted.internet import defer
+from twisted.web.http import Request
 from twisted.web.http_headers import Headers
 
 from synapse.config import ConfigError
@@ -219,11 +220,12 @@ class _DummyTagNames:
 
 try:
     import opentracing
+    import opentracing.tags
 
     tags = opentracing.tags
 except ImportError:
-    opentracing = None
-    tags = _DummyTagNames
+    opentracing = None  # type: ignore[assignment]
+    tags = _DummyTagNames  # type: ignore[assignment]
 try:
     from jaeger_client import Config as JaegerConfig
 
@@ -366,7 +368,7 @@ def init_tracer(hs: "HomeServer"):
     global opentracing
     if not hs.config.tracing.opentracer_enabled:
         # We don't have a tracer
-        opentracing = None
+        opentracing = None  # type: ignore[assignment]
         return
 
     if not opentracing or not JaegerConfig:
@@ -452,7 +454,7 @@ def start_active_span(
     """
 
     if opentracing is None:
-        return noop_context_manager()
+        return noop_context_manager()  # type: ignore[unreachable]
 
     return opentracing.tracer.start_active_span(
         operation_name,
@@ -477,7 +479,7 @@ def start_active_span_follows_from(
            forced, the new span will also have tracing forced.
     """
     if opentracing is None:
-        return noop_context_manager()
+        return noop_context_manager()  # type: ignore[unreachable]
 
     references = [opentracing.follows_from(context) for context in contexts]
     scope = start_active_span(operation_name, references=references)
@@ -490,48 +492,6 @@ def start_active_span_follows_from(
     return scope
 
 
-def start_active_span_from_request(
-    request,
-    operation_name,
-    references=None,
-    tags=None,
-    start_time=None,
-    ignore_active_span=False,
-    finish_on_close=True,
-):
-    """
-    Extracts a span context from a Twisted Request.
-    args:
-        headers (twisted.web.http.Request)
-
-        For the other args see opentracing.tracer
-
-    returns:
-        span_context (opentracing.span.SpanContext)
-    """
-    # Twisted encodes the values as lists whereas opentracing doesn't.
-    # So, we take the first item in the list.
-    # Also, twisted uses byte arrays while opentracing expects strings.
-
-    if opentracing is None:
-        return noop_context_manager()
-
-    header_dict = {
-        k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
-    }
-    context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
-
-    return opentracing.tracer.start_active_span(
-        operation_name,
-        child_of=context,
-        references=references,
-        tags=tags,
-        start_time=start_time,
-        ignore_active_span=ignore_active_span,
-        finish_on_close=finish_on_close,
-    )
-
-
 def start_active_span_from_edu(
     edu_content,
     operation_name,
@@ -553,7 +513,7 @@ def start_active_span_from_edu(
     references = references or []
 
     if opentracing is None:
-        return noop_context_manager()
+        return noop_context_manager()  # type: ignore[unreachable]
 
     carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
         "opentracing", {}
@@ -594,18 +554,21 @@ def active_span():
 @ensure_active_span("set a tag")
 def set_tag(key, value):
     """Sets a tag on the active span"""
+    assert opentracing.tracer.active_span is not None
     opentracing.tracer.active_span.set_tag(key, value)
 
 
 @ensure_active_span("log")
 def log_kv(key_values, timestamp=None):
     """Log to the active span"""
+    assert opentracing.tracer.active_span is not None
     opentracing.tracer.active_span.log_kv(key_values, timestamp)
 
 
 @ensure_active_span("set the traces operation name")
 def set_operation_name(operation_name):
     """Sets the operation name of the active span"""
+    assert opentracing.tracer.active_span is not None
     opentracing.tracer.active_span.set_operation_name(operation_name)
 
 
@@ -674,6 +637,7 @@ def inject_header_dict(
     span = opentracing.tracer.active_span
 
     carrier: Dict[str, str] = {}
+    assert span is not None
     opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
 
     for key, value in carrier.items():
@@ -716,6 +680,7 @@ def get_active_span_text_map(destination=None):
         return {}
 
     carrier: Dict[str, str] = {}
+    assert opentracing.tracer.active_span is not None
     opentracing.tracer.inject(
         opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
     )
@@ -731,12 +696,27 @@ def active_span_context_as_string():
     """
     carrier: Dict[str, str] = {}
     if opentracing:
+        assert opentracing.tracer.active_span is not None
         opentracing.tracer.inject(
             opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
         )
     return json_encoder.encode(carrier)
 
 
+def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]":
+    """Extract an opentracing context from the headers on an HTTP request
+
+    This is useful when we have received an HTTP request from another part of our
+    system, and want to link our spans to those of the remote system.
+    """
+    if not opentracing:
+        return None
+    header_dict = {
+        k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
+    }
+    return opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
+
+
 @only_if_tracing
 def span_context_from_string(carrier):
     """
@@ -773,7 +753,7 @@ def trace(func=None, opname=None):
 
     def decorator(func):
         if opentracing is None:
-            return func
+            return func  # type: ignore[unreachable]
 
         _opname = opname if opname else func.__name__
 
@@ -864,7 +844,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
     """
 
     if opentracing is None:
-        yield
+        yield  # type: ignore[unreachable]
         return
 
     request_tags = {
@@ -876,10 +856,13 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
     }
 
     request_name = request.request_metrics.name
-    if extract_context:
-        scope = start_active_span_from_request(request, request_name)
-    else:
-        scope = start_active_span(request_name)
+    context = span_context_from_request(request) if extract_context else None
+
+    # we configure the scope not to finish the span immediately on exit, and instead
+    # pass the span into the SynapseRequest, which will finish it once we've finished
+    # sending the response to the client.
+    scope = start_active_span(request_name, child_of=context, finish_on_close=False)
+    request.set_opentracing_span(scope.span)
 
     with scope:
         inject_response_headers(request.responseHeaders)
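
span_context_from_request only reads back what the calling homeserver injected into the request headers. A rough sketch of both ends of that contract, using the plain opentracing API and assuming a tracer has already been configured:

    import opentracing

    # Sending side (what inject_header_dict does for outbound requests):
    # serialise the active span context into a header carrier.
    carrier: dict = {}
    span = opentracing.tracer.active_span
    opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)

    # Receiving side (what span_context_from_request does): rebuild the remote
    # context from those headers so the local span can be started as its child.
    remote_context = opentracing.tracer.extract(
        opentracing.Format.HTTP_HEADERS, carrier
    )
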
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index b1e8e08fe9..db8ca2c049 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -71,7 +71,7 @@ class LogContextScopeManager(ScopeManager):
         if not ctx:
             # We don't want this scope to affect.
             logger.error("Tried to activate scope outside of loggingcontext")
-            return Scope(None, span)
+            return Scope(None, span)  # type: ignore[arg-type]
         elif ctx.scope is not None:
             # We want the logging scope to look exactly the same so we give it
             # a blank suffix
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 60e5409895..bbabdb0587 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
 from typing import (
     Awaitable,
     Callable,
@@ -44,7 +43,13 @@ from synapse.logging.opentracing import log_kv, start_active_span
 from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.streams.config import PaginationConfig
-from synapse.types import PersistedEventPosition, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+    JsonDict,
+    PersistedEventPosition,
+    RoomStreamToken,
+    StreamToken,
+    UserID,
+)
 from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.visibility import filter_events_for_client
@@ -178,7 +183,12 @@ class _NotifierUserStream:
             return _NotificationListener(self.notify_deferred.observe())
 
 
-class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventStreamResult:
+    events: List[Union[JsonDict, EventBase]]
+    start_token: StreamToken
+    end_token: StreamToken
+
     def __bool__(self):
         return bool(self.events)
 
@@ -582,9 +592,12 @@ class Notifier:
             before_token: StreamToken, after_token: StreamToken
         ) -> EventStreamResult:
             if after_token == before_token:
-                return EventStreamResult([], (from_token, from_token))
+                return EventStreamResult([], from_token, from_token)
 
-            events: List[EventBase] = []
+            # The events fetched from each source are a JsonDict, EventBase, or
+            # UserPresenceState, but see below for UserPresenceState being
+            # converted to JsonDict.
+            events: List[Union[JsonDict, EventBase]] = []
             end_token = from_token
 
             for name, source in self.event_sources.sources.get_sources():
@@ -623,7 +636,7 @@ class Notifier:
                 events.extend(new_events)
                 end_token = end_token.copy_and_replace(keyname, new_key)
 
-            return EventStreamResult(events, (from_token, end_token))
+            return EventStreamResult(events, from_token, end_token)
 
         user_id_for_stream = user.to_string()
         if is_peeking:
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 4f13c0418a..39bb2acae4 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -177,12 +177,12 @@ class EmailPusher(Pusher):
             return
 
         for push_action in unprocessed:
-            received_at = push_action["received_ts"]
+            received_at = push_action.received_ts
             if received_at is None:
                 received_at = 0
             notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
 
-            room_ready_at = self.room_ready_to_notify_at(push_action["room_id"])
+            room_ready_at = self.room_ready_to_notify_at(push_action.room_id)
 
             should_notify_at = max(notif_ready_at, room_ready_at)
 
@@ -193,23 +193,23 @@ class EmailPusher(Pusher):
                 # to be delivered.
 
                 reason: EmailReason = {
-                    "room_id": push_action["room_id"],
+                    "room_id": push_action.room_id,
                     "now": self.clock.time_msec(),
                     "received_at": received_at,
                     "delay_before_mail_ms": DELAY_BEFORE_MAIL_MS,
-                    "last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]),
-                    "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]),
+                    "last_sent_ts": self.get_room_last_sent_ts(push_action.room_id),
+                    "throttle_ms": self.get_room_throttle_ms(push_action.room_id),
                 }
 
                 await self.send_notification(unprocessed, reason)
 
                 await self.save_last_stream_ordering_and_success(
-                    max(ea["stream_ordering"] for ea in unprocessed)
+                    max(ea.stream_ordering for ea in unprocessed)
                 )
 
                 # we update the throttle on all the possible unprocessed push actions
                 for ea in unprocessed:
-                    await self.sent_notif_update_throttle(ea["room_id"], ea)
+                    await self.sent_notif_update_throttle(ea.room_id, ea)
                 break
             else:
                 if soonest_due_at is None or should_notify_at < soonest_due_at:
@@ -284,10 +284,10 @@ class EmailPusher(Pusher):
         # THROTTLE_RESET_AFTER_MS after the previous one that triggered a
         # notif, we release the throttle. Otherwise, the throttle is increased.
         time_of_previous_notifs = await self.store.get_time_of_last_push_action_before(
-            notified_push_action["stream_ordering"]
+            notified_push_action.stream_ordering
         )
 
-        time_of_this_notifs = notified_push_action["received_ts"]
+        time_of_this_notifs = notified_push_action.received_ts
 
         if time_of_previous_notifs is not None and time_of_this_notifs is not None:
             gap = time_of_this_notifs - time_of_previous_notifs
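
The push actions handed to the pushers are no longer plain dicts, hence the switch from subscripting to attribute access throughout. A minimal sketch of an attrs record with the fields read in this file; the real class lives in the storage layer, and this stand-in only mirrors the attribute names used above:

    from typing import List, Optional

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class ExamplePushAction:  # illustrative stand-in for the storage-layer class
        event_id: str
        room_id: str
        stream_ordering: int
        actions: List[str]
        received_ts: Optional[int]

    pa = ExamplePushAction("$ev:hs", "!room:hs", 42, ["notify"], 1_638_000_000_000)
    assert pa.room_id == "!room:hs"  # attribute access replaces pa["room_id"]
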
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 00f42d4dcb..fece3796cc 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -204,7 +204,7 @@ class HttpPusher(Pusher):
                 "http-push",
                 tags={
                     "authenticated_entity": self.user_id,
-                    "event_id": push_action["event_id"],
+                    "event_id": push_action.event_id,
                     "app_id": self.app_id,
                     "app_display_name": self.app_display_name,
                 },
@@ -214,7 +214,7 @@ class HttpPusher(Pusher):
             if processed:
                 http_push_processed_counter.inc()
                 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
-                self.last_stream_ordering = push_action["stream_ordering"]
+                self.last_stream_ordering = push_action.stream_ordering
                 pusher_still_exists = (
                     await self.store.update_pusher_last_stream_ordering_and_success(
                         self.app_id,
@@ -257,7 +257,7 @@ class HttpPusher(Pusher):
                         self.pushkey,
                     )
                     self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
-                    self.last_stream_ordering = push_action["stream_ordering"]
+                    self.last_stream_ordering = push_action.stream_ordering
                     await self.store.update_pusher_last_stream_ordering(
                         self.app_id,
                         self.pushkey,
@@ -280,17 +280,17 @@ class HttpPusher(Pusher):
                     break
 
     async def _process_one(self, push_action: HttpPushAction) -> bool:
-        if "notify" not in push_action["actions"]:
+        if "notify" not in push_action.actions:
             return True
 
-        tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
+        tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
         badge = await push_tools.get_badge_count(
             self.hs.get_datastore(),
             self.user_id,
             group_by_room=self._group_unread_count_by_room,
         )
 
-        event = await self.store.get_event(push_action["event_id"], allow_none=True)
+        event = await self.store.get_event(push_action.event_id, allow_none=True)
         if event is None:
             return True  # It's been redacted
         rejected = await self.dispatch_push(event, tweaks, badge)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index ba4f866487..ff904c2b4a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -232,15 +232,13 @@ class Mailer:
             reason: The notification that was ready and is the cause of an email
                 being sent.
         """
-        rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
+        rooms_in_order = deduped_ordered_list([pa.room_id for pa in push_actions])
 
-        notif_events = await self.store.get_events(
-            [pa["event_id"] for pa in push_actions]
-        )
+        notif_events = await self.store.get_events([pa.event_id for pa in push_actions])
 
         notifs_by_room: Dict[str, List[EmailPushAction]] = {}
         for pa in push_actions:
-            notifs_by_room.setdefault(pa["room_id"], []).append(pa)
+            notifs_by_room.setdefault(pa.room_id, []).append(pa)
 
         # collect the current state for all the rooms in which we have
         # notifications
@@ -264,7 +262,7 @@ class Mailer:
         await concurrently_execute(_fetch_room_state, rooms_in_order, 3)
 
         # actually sort our so-called rooms_in_order list, most recent room first
-        rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
+        rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1].received_ts or 0))
 
         rooms: List[RoomVars] = []
 
@@ -356,7 +354,7 @@ class Mailer:
         # Check if one of the notifs is an invite event for the user.
         is_invite = False
         for n in notifs:
-            ev = notif_events[n["event_id"]]
+            ev = notif_events[n.event_id]
             if ev.type == EventTypes.Member and ev.state_key == user_id:
                 if ev.content.get("membership") == Membership.INVITE:
                     is_invite = True
@@ -376,7 +374,7 @@ class Mailer:
         if not is_invite:
             for n in notifs:
                 notifvars = await self._get_notif_vars(
-                    n, user_id, notif_events[n["event_id"]], room_state_ids
+                    n, user_id, notif_events[n.event_id], room_state_ids
                 )
 
                 # merge overlapping notifs together.
@@ -444,15 +442,15 @@ class Mailer:
         """
 
         results = await self.store.get_events_around(
-            notif["room_id"],
-            notif["event_id"],
+            notif.room_id,
+            notif.event_id,
             before_limit=CONTEXT_BEFORE,
             after_limit=CONTEXT_AFTER,
         )
 
         ret: NotifVars = {
             "link": self._make_notif_link(notif),
-            "ts": notif["received_ts"],
+            "ts": notif.received_ts,
             "messages": [],
         }
 
@@ -516,7 +514,7 @@ class Mailer:
 
         ret: MessageVars = {
             "event_type": event.type,
-            "is_historical": event.event_id != notif["event_id"],
+            "is_historical": event.event_id != notif.event_id,
             "id": event.event_id,
             "ts": event.origin_server_ts,
             "sender_name": sender_name,
@@ -610,7 +608,7 @@ class Mailer:
         # See if one of the notifs is an invite event for the user
         invite_event = None
         for n in notifs:
-            ev = notif_events[n["event_id"]]
+            ev = notif_events[n.event_id]
             if ev.type == EventTypes.Member and ev.state_key == user_id:
                 if ev.content.get("membership") == Membership.INVITE:
                     invite_event = ev
@@ -659,7 +657,7 @@ class Mailer:
         if len(notifs) == 1:
             # There is just the one notification, so give some detail
             sender_name = None
-            event = notif_events[notifs[0]["event_id"]]
+            event = notif_events[notifs[0].event_id]
             if ("m.room.member", event.sender) in room_state_ids:
                 state_event_id = room_state_ids[("m.room.member", event.sender)]
                 state_event = await self.store.get_event(state_event_id)
@@ -753,9 +751,9 @@ class Mailer:
         # are already in descending received_ts.
         sender_ids = {}
         for n in notifs:
-            sender = notif_events[n["event_id"]].sender
+            sender = notif_events[n.event_id].sender
             if sender not in sender_ids:
-                sender_ids[sender] = n["event_id"]
+                sender_ids[sender] = n.event_id
 
         # Get the actual member events (in order to calculate a pretty name for
         # the room).
@@ -830,17 +828,17 @@ class Mailer:
         if self.hs.config.email.email_riot_base_url:
             return "%s/#/room/%s/%s" % (
                 self.hs.config.email.email_riot_base_url,
-                notif["room_id"],
-                notif["event_id"],
+                notif.room_id,
+                notif.event_id,
             )
         elif self.app_name == "Vector":
             # need /beta for Universal Links to work on iOS
             return "https://vector.im/beta/#/room/%s/%s" % (
-                notif["room_id"],
-                notif["event_id"],
+                notif.room_id,
+                notif.event_id,
             )
         else:
-            return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
+            return "https://matrix.to/#/%s/%s" % (notif.room_id, notif.event_id)
 
     def _make_unsubscribe_link(
         self, user_id: str, app_id: str, email_address: str
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 7f68092ec5..659a53805d 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -17,9 +17,10 @@ import logging
 import re
 from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
 
+from matrix_common.regex import glob_to_regex, to_word_pattern
+
 from synapse.events import EventBase
 from synapse.types import JsonDict, UserID
-from synapse.util import glob_to_regex, re_word_boundary
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -184,7 +185,7 @@ class PushRuleEvaluatorForEvent:
         r = regex_cache.get((display_name, False, True), None)
         if not r:
             r1 = re.escape(display_name)
-            r1 = re_word_boundary(r1)
+            r1 = to_word_pattern(r1)
             r = re.compile(r1, flags=re.IGNORECASE)
             regex_cache[(display_name, False, True)] = r
 
@@ -213,7 +214,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
     try:
         r = regex_cache.get((glob, True, word_boundary), None)
         if not r:
-            r = glob_to_regex(glob, word_boundary)
+            r = glob_to_regex(glob, word_boundary=word_boundary)
             regex_cache[(glob, True, word_boundary)] = r
         return bool(r.search(value))
     except re.error:
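
The evaluator now takes these helpers from the new matrix-common dependency (pinned to 1.0.0 below) instead of synapse.util, passing word_boundary by keyword. A small sketch of how the helpers behave, based on my reading of matrix-common 1.0.0 and the call sites above; treat the exact match results as assumptions:

    import re

    from matrix_common.regex import glob_to_regex, to_word_pattern

    # A Matrix-style glob ("*" / "?") compiled to a regex; word_boundary is
    # passed by keyword, matching the call above.
    user_pattern = glob_to_regex("@alice*:example.org", word_boundary=False)
    assert user_pattern.search("@alice:example.org")

    # to_word_pattern returns a pattern *string* wrapped in word boundaries,
    # which the evaluator then compiles case-insensitively for display names.
    name_regex = re.compile(to_word_pattern(re.escape("alice")), flags=re.IGNORECASE)
    assert name_regex.search("Hi Alice!")
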
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 9c85200c0f..957c9b780b 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Dict
 
+from synapse.api.constants import ReceiptTypes
 from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
 from synapse.storage import Storage
@@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     invites = await store.get_invited_rooms_for_local_user(user_id)
     joins = await store.get_rooms_for_user(user_id)
 
-    my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
+    my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)
 
     badge = len(invites)
 
@@ -36,7 +37,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
                     room_id, user_id, last_unread_event_id
                 )
             )
-            if notifs["notify_count"] == 0:
+            if notifs.notify_count == 0:
                 continue
 
             if group_by_room:
@@ -44,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
                 badge += 1
             else:
                 # increment the badge count by the number of unread messages in the room
-                badge += notifs["notify_count"]
+                badge += notifs.notify_count
     return badge
 
 
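
Two small changes above: the receipt type moves from the string literal "m.read" to the shared ReceiptTypes.READ constant, and the unread counts come back as an object with a notify_count attribute instead of a dict. A compact sketch of the badge arithmetic, with an illustrative constants class standing in for synapse.api.constants.ReceiptTypes:

from typing import List

class ReceiptTypesSketch:  # illustrative stand-in for the real ReceiptTypes constants
    READ = "m.read"

def count_badge(num_invites: int, notify_counts: List[int], group_by_room: bool) -> int:
    # Mirrors get_badge_count: every invite counts once, then either one per
    # room with unread messages or the full per-room unread total.
    badge = num_invites
    for notify_count in notify_counts:
        if notify_count == 0:
            continue
        badge += 1 if group_by_room else notify_count
    return badge
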
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 26735447a6..7912311d24 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -27,6 +27,7 @@ from synapse.push.pusher import PusherFactory
 from synapse.replication.http.push import ReplicationRemovePusherRestServlet
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
+from synapse.util.threepids import canonicalise_email
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -113,7 +114,9 @@ class PusherPool:
         """
 
         if kind == "email":
-            email_owner = await self.store.get_user_id_by_threepid("email", pushkey)
+            email_owner = await self.store.get_user_id_by_threepid(
+                "email", canonicalise_email(pushkey)
+            )
             if email_owner != user_id:
                 raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
 
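
Here the pushkey (an email address when kind == "email") is canonicalised before the threepid lookup, since addresses are stored in canonical form; without this, a differently-cased pushkey would fail to match its owner. A rough sketch of the normalisation step (the real canonicalise_email also validates the address more carefully):

def canonicalise_email_sketch(address: str) -> str:
    # Lower-case and trim, so the lookup key matches how the threepid was stored.
    # Illustrative only; not the real synapse.util.threepids implementation.
    address = address.strip().lower()
    if address.count("@") != 1:
        raise ValueError("Unable to parse email address")
    return address

assert canonicalise_email_sketch(" User@Example.COM ") == "user@example.com"
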
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 7d26954244..d844fbb3b3 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -50,7 +50,8 @@ logger = logging.getLogger(__name__)
 REQUIREMENTS = [
     # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0
     "jsonschema>=3.0.0",
-    "frozendict>=1",
+    # frozendict 2.1.2 is broken on Debian 10: https://github.com/Marco-Sulla/python-frozendict/issues/41
+    "frozendict>=1,<2.1.2",
     "unpaddedbase64>=1.1.0",
     "canonicaljson>=1.4.0",
     # we use the type definitions added in signedjson 1.1.
@@ -87,6 +88,7 @@ REQUIREMENTS = [
     # with the latest security patches.
     "cryptography>=3.4.7",
     "ijson>=3.1",
+    "matrix-common==1.0.0",
 ]
 
 CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 7ecb446e7c..7644146dba 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Optional
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -27,7 +27,12 @@ logger = logging.getLogger(__name__)
 
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen: Optional[
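
This and the following slave-storage files make the same mechanical change: the previously untyped db_conn constructor argument is annotated as LoggingDatabaseConnection, so mypy can check how the connection is used. The shape of the change, schematically:

from typing import Any

class LoggingDatabaseConnectionSketch:  # stand-in for the type imported above
    pass

class ExampleSlavedStore:
    def __init__(
        self,
        database: Any,                             # DatabasePool in Synapse
        db_conn: LoggingDatabaseConnectionSketch,  # was untyped before this change
        hs: Any,                                   # "HomeServer"
    ) -> None:
        self.db_conn = db_conn
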
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 61cd7e5228..bc888ce1a8 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -14,7 +14,7 @@
 
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.lrucache import LruCache
 
@@ -25,7 +25,12 @@ if TYPE_CHECKING:
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0a58296089..a2aff75b70 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.devices import DeviceWorkerStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -27,7 +27,12 @@ if TYPE_CHECKING:
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 63ed50caa5..0f08372694 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -58,7 +58,12 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
@@ -75,12 +80,3 @@ class SlavedEventStore(
             min_curr_state_delta_id,
             prefilled_cache=curr_state_delta_prefill,
         )
-
-    # Cached functions can't be accessed through a class instance so we need
-    # to reach inside the __dict__ to extract them.
-
-    def get_room_max_stream_ordering(self):
-        return self._stream_id_gen.get_current_token()
-
-    def get_room_min_stream_ordering(self):
-        return self._backfill_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 90284c202d..4d185e2b56 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,7 +14,7 @@
 
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.filtering import FilteringStore
 
 from ._base import BaseSlavedStore
@@ -24,7 +24,12 @@ if TYPE_CHECKING:
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 497e16c69e..9d90e26375 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import GroupServerStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.group_server import GroupServerWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -26,7 +26,12 @@ if TYPE_CHECKING:
 
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 743a01da08..5a2d90c530 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -15,7 +15,6 @@
 
 import heapq
 import logging
-from collections import namedtuple
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -30,6 +29,7 @@ from typing import (
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
+from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -226,17 +226,14 @@ class BackfillStream(Stream):
     or it went from being an outlier to not.
     """
 
-    BackfillStreamRow = namedtuple(
-        "BackfillStreamRow",
-        (
-            "event_id",  # str
-            "room_id",  # str
-            "type",  # str
-            "state_key",  # str, optional
-            "redacts",  # str, optional
-            "relates_to",  # str, optional
-        ),
-    )
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class BackfillStreamRow:
+        event_id: str
+        room_id: str
+        type: str
+        state_key: Optional[str]
+        redacts: Optional[str]
+        relates_to: Optional[str]
 
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
@@ -256,18 +253,15 @@ class BackfillStream(Stream):
 
 
 class PresenceStream(Stream):
-    PresenceStreamRow = namedtuple(
-        "PresenceStreamRow",
-        (
-            "user_id",  # str
-            "state",  # str
-            "last_active_ts",  # int
-            "last_federation_update_ts",  # int
-            "last_user_sync_ts",  # int
-            "status_msg",  # str
-            "currently_active",  # bool
-        ),
-    )
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class PresenceStreamRow:
+        user_id: str
+        state: str
+        last_active_ts: int
+        last_federation_update_ts: int
+        last_user_sync_ts: int
+        status_msg: str
+        currently_active: bool
 
     NAME = "presence"
     ROW_TYPE = PresenceStreamRow
@@ -302,7 +296,7 @@ class PresenceFederationStream(Stream):
     send.
     """
 
-    @attr.s(slots=True, auto_attribs=True)
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
     class PresenceFederationStreamRow:
         destination: str
         user_id: str
@@ -320,9 +314,10 @@ class PresenceFederationStream(Stream):
 
 
 class TypingStream(Stream):
-    TypingStreamRow = namedtuple(
-        "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
-    )
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class TypingStreamRow:
+        room_id: str
+        user_ids: List[str]
 
     NAME = "typing"
     ROW_TYPE = TypingStreamRow
@@ -348,16 +343,13 @@ class TypingStream(Stream):
 
 
 class ReceiptsStream(Stream):
-    ReceiptsStreamRow = namedtuple(
-        "ReceiptsStreamRow",
-        (
-            "room_id",  # str
-            "receipt_type",  # str
-            "user_id",  # str
-            "event_id",  # str
-            "data",  # dict
-        ),
-    )
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class ReceiptsStreamRow:
+        room_id: str
+        receipt_type: str
+        user_id: str
+        event_id: str
+        data: dict
 
     NAME = "receipts"
     ROW_TYPE = ReceiptsStreamRow
@@ -374,7 +366,9 @@ class ReceiptsStream(Stream):
 class PushRulesStream(Stream):
     """A user has changed their push rules"""
 
-    PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class PushRulesStreamRow:
+        user_id: str
 
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
@@ -396,10 +390,12 @@ class PushRulesStream(Stream):
 class PushersStream(Stream):
     """A user has added/changed/removed a pusher"""
 
-    PushersStreamRow = namedtuple(
-        "PushersStreamRow",
-        ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
-    )
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class PushersStreamRow:
+        user_id: str
+        app_id: str
+        pushkey: str
+        deleted: bool
 
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
@@ -419,7 +415,7 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
-    @attr.s(slots=True)
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
     class CachesStreamRow:
         """Stream to inform workers they should invalidate their cache.
 
@@ -430,9 +426,9 @@ class CachesStream(Stream):
             invalidation_ts: Timestamp of when the invalidation took place.
         """
 
-        cache_func = attr.ib(type=str)
-        keys = attr.ib(type=Optional[List[Any]])
-        invalidation_ts = attr.ib(type=int)
+        cache_func: str
+        keys: Optional[List[Any]]
+        invalidation_ts: int
 
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
@@ -451,9 +447,9 @@ class DeviceListsStream(Stream):
     told about a device update.
     """
 
-    @attr.s(slots=True)
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
     class DeviceListsStreamRow:
-        entity = attr.ib(type=str)
+        entity: str
 
     NAME = "device_lists"
     ROW_TYPE = DeviceListsStreamRow
@@ -470,7 +466,9 @@ class DeviceListsStream(Stream):
 class ToDeviceStream(Stream):
     """New to_device messages for a client"""
 
-    ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class ToDeviceStreamRow:
+        entity: str
 
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
@@ -487,9 +485,11 @@ class ToDeviceStream(Stream):
 class TagAccountDataStream(Stream):
     """Someone added/removed a tag for a room"""
 
-    TagAccountDataStreamRow = namedtuple(
-        "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
-    )
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class TagAccountDataStreamRow:
+        user_id: str
+        room_id: str
+        data: JsonDict
 
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
@@ -506,10 +506,11 @@ class TagAccountDataStream(Stream):
 class AccountDataStream(Stream):
     """Global or per room account data was changed"""
 
-    AccountDataStreamRow = namedtuple(
-        "AccountDataStreamRow",
-        ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
-    )
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class AccountDataStreamRow:
+        user_id: str
+        room_id: Optional[str]
+        data_type: str
 
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
@@ -573,10 +574,12 @@ class AccountDataStream(Stream):
 
 
 class GroupServerStream(Stream):
-    GroupsStreamRow = namedtuple(
-        "GroupsStreamRow",
-        ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
-    )
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class GroupsStreamRow:
+        group_id: str
+        user_id: str
+        type: str
+        content: JsonDict
 
     NAME = "groups"
     ROW_TYPE = GroupsStreamRow
@@ -593,7 +596,9 @@ class GroupServerStream(Stream):
 class UserSignatureStream(Stream):
     """A user has signed their own device with their user-signing key"""
 
-    UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class UserSignatureStreamRow:
+        user_id: str
 
     NAME = "user_signature"
     ROW_TYPE = UserSignatureStreamRow
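
Every row type in this file follows the same conversion: an untyped namedtuple becomes a frozen, slotted attrs class, so the fields carry real type annotations and the rows stay immutable and cheap to construct. A small self-contained example of the before/after:

import attr
from typing import Optional

# Before (schematic):
#   ExampleStreamRow = namedtuple("ExampleStreamRow", ("event_id", "room_id", "state_key"))

@attr.s(slots=True, frozen=True, auto_attribs=True)
class ExampleStreamRow:
    event_id: str
    room_id: str
    state_key: Optional[str]

row = ExampleStreamRow(event_id="$abc", room_id="!room:example.org", state_key=None)
# Assigning row.event_id = "..." would raise attr.exceptions.FrozenInstanceError.
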
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 0600cdbf36..4046bdec69 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -12,14 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from collections import namedtuple
 from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple
 
+import attr
+
 from synapse.replication.tcp.streams._base import (
     Stream,
     current_token_without_instance,
     make_http_update_function,
 )
+from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -30,13 +32,10 @@ class FederationStream(Stream):
     sending disabled.
     """
 
-    FederationStreamRow = namedtuple(
-        "FederationStreamRow",
-        (
-            "type",  # str, the type of data as defined in the BaseFederationRows
-            "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
-        ),
-    )
+    @attr.s(slots=True, frozen=True, auto_attribs=True)
+    class FederationStreamRow:
+        type: str  # the type of data as defined in the BaseFederationRows
+        data: JsonDict  # serialization of a federation.send_queue.BaseFederationRow
 
     NAME = "federation"
     ROW_TYPE = FederationStreamRow
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index c499afd4be..465e06772b 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -69,6 +69,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
 from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
 from synapse.rest.admin.username_available import UsernameAvailableRestServlet
 from synapse.rest.admin.users import (
+    AccountDataRestServlet,
     AccountValidityRenewServlet,
     DeactivateAccountRestServlet,
     PushersRestServlet,
@@ -108,7 +109,7 @@ class VersionServlet(RestServlet):
 
 class PurgeHistoryRestServlet(RestServlet):
     PATTERNS = admin_patterns(
-        "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
+        "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$"
     )
 
     def __init__(self, hs: "HomeServer"):
@@ -195,7 +196,7 @@ class PurgeHistoryRestServlet(RestServlet):
 
 
 class PurgeHistoryStatusRestServlet(RestServlet):
-    PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
+    PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.pagination_handler = hs.get_pagination_handler()
@@ -255,6 +256,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     UserMediaStatisticsRestServlet(hs).register(http_server)
     EventReportDetailRestServlet(hs).register(http_server)
     EventReportsRestServlet(hs).register(http_server)
+    AccountDataRestServlet(hs).register(http_server)
     PushersRestServlet(hs).register(http_server)
     MakeRoomAdminRestServlet(hs).register(http_server)
     ShadowBanRestServlet(hs).register(http_server)
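
Alongside the new AccountDataRestServlet registration, the admin URL patterns gain a trailing $ (and [^/]+ becomes [^/]*). The anchor stops a request with extra trailing path segments from being routed to the wrong servlet, while the * lets an empty identifier reach the servlet and produce a proper error. A small demonstration with plain re (the real patterns are additionally prefixed with the /_synapse/admin/... path by admin_patterns):

import re

unanchored = re.compile("^/purge_history_status/(?P<purge_id>[^/]+)")
anchored = re.compile("^/purge_history_status/(?P<purge_id>[^/]*)$")

assert unanchored.match("/purge_history_status/abc/extra")         # over-matches
assert anchored.match("/purge_history_status/abc/extra") is None   # now rejected
assert anchored.match("/purge_history_status/")                    # empty id reaches the servlet
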
diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py
index 479672d4d5..6ec00ce0b9 100644
--- a/synapse/rest/admin/background_updates.py
+++ b/synapse/rest/admin/background_updates.py
@@ -22,7 +22,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
 )
 from synapse.http.site import SynapseRequest
-from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
         self._data_stores = hs.get_datastores()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_requester_is_admin(self._auth, request)
 
         # We need to check that all configured databases have updates enabled.
         # (They *should* all be in sync.)
@@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
         return HTTPStatus.OK, {"enabled": enabled}
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_requester_is_admin(self._auth, request)
 
         body = parse_json_object_from_request(request)
 
@@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet):
         self._data_stores = hs.get_datastores()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_requester_is_admin(self._auth, request)
 
         # We need to check that all configured databases have updates enabled.
         # (They *should* all be in sync.)
@@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet):
 class BackgroundUpdateStartJobRestServlet(RestServlet):
     """Allows to start specific background updates"""
 
-    PATTERNS = admin_patterns("/background_updates/start_job")
+    PATTERNS = admin_patterns("/background_updates/start_job$")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
         self._store = hs.get_datastore()
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self._auth.get_user_by_req(request)
-        await assert_user_is_admin(self._auth, requester.user)
+        await assert_requester_is_admin(self._auth, request)
 
         body = parse_json_object_from_request(request)
         assert_params_in_dict(body, ["job_name"])
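
The repeated two-step check (get_user_by_req followed by assert_user_is_admin) collapses into the existing assert_requester_is_admin helper. Roughly what that helper folds together (a schematic, not the real implementation in synapse.rest.admin._base):

async def assert_requester_is_admin_sketch(auth, request) -> None:
    # Resolve the requester from the request, then check admin status.
    requester = await auth.get_user_by_req(request)
    if not await auth.is_server_admin(requester.user):
        # Synapse raises an AuthError(403, ...) here; a plain exception stands in.
        raise PermissionError("You are not a server admin")
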
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 2e5a6600d3..d9905ff560 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str, device_id: str
@@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
@@ -63,6 +63,8 @@ class DeviceRestServlet(RestServlet):
         device = await self.device_handler.get_device(
             target_user.to_string(), device_id
         )
+        if device is None:
+            raise NotFoundError("No device found")
         return HTTPStatus.OK, device
 
     async def on_DELETE(
@@ -71,7 +73,7 @@ class DeviceRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
@@ -87,7 +89,7 @@ class DeviceRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
@@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
 
     def __init__(self, hs: "HomeServer"):
-        """
-        Args:
-            hs: server
-        """
-        self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
@@ -124,7 +122,7 @@ class DevicesRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
@@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()
+        self.is_mine = hs.is_mine
 
     async def on_POST(
         self, request: SynapseRequest, user_id: str
@@ -155,7 +153,7 @@ class DeleteDevicesRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
 
         u = await self.store.get_user_by_id(target_user.to_string())
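
Two recurring tweaks in these device servlets: the constructor keeps only hs.is_mine instead of holding the whole HomeServer, and a missing device now raises NotFoundError instead of returning None to the client. A schematic of the lookup path (illustrative names; the real handlers return an HTTP status/body tuple from on_GET):

class DeviceLookupSketch:
    def __init__(self, hs):
        self.is_mine = hs.is_mine                      # keep only what is used
        self.device_handler = hs.get_device_handler()

    async def get(self, target_user, device_id: str):
        if not self.is_mine(target_user):
            raise ValueError("Can only lookup local users")  # SynapseError(400, ...) in Synapse
        device = await self.device_handler.get_device(target_user.to_string(), device_id)
        if device is None:
            raise KeyError("No device found")                # NotFoundError in Synapse
        return device
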
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 5ee8b11110..38477f8ead 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet):
     PATTERNS = admin_patterns("/event_reports$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet):
     PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 744687be35..50d88c9109 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet):
         200 OK with details of a destination if success otherwise an error.
     """
 
-    PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
+    PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index a27110388f..cd697e180e 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 class DeleteGroupAdminRestServlet(RestServlet):
     """Allows deleting of local groups"""
 
-    PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
+    PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.group_server = hs.get_groups_server_handler()
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 9e23e2d8fc..7236e4027f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,7 +17,7 @@ import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
@@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet):
     """
 
     PATTERNS = [
-        *admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"),
+        *admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"),
         # This path kept around for legacy reasons
-        *admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"),
+        *admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"),
     ]
 
     def __init__(self, hs: "HomeServer"):
@@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet):
     this server.
     """
 
-    PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$")
+    PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet):
     """
 
     PATTERNS = admin_patterns(
-        "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+        "/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
     )
 
     def __init__(self, hs: "HomeServer"):
@@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet):
     """
 
     PATTERNS = admin_patterns(
-        "/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+        "/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
     )
 
     def __init__(self, hs: "HomeServer"):
@@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet):
     async def on_POST(
         self, request: SynapseRequest, server_name: str, media_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         logging.info(
             "Remove from quarantine local media by ID: %s/%s", server_name, media_id
@@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet):
 class ProtectMediaByID(RestServlet):
     """Protect local media from being quarantined."""
 
-    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+    PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet):
     async def on_POST(
         self, request: SynapseRequest, media_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         logging.info("Protecting local media by ID: %s", media_id)
 
@@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet):
 class UnprotectMediaByID(RestServlet):
     """Unprotect local media from being quarantined."""
 
-    PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)")
+    PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet):
     async def on_POST(
         self, request: SynapseRequest, media_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         logging.info("Unprotecting local media by ID: %s", media_id)
 
@@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet):
 class ListMediaInRoom(RestServlet):
     """Lists all of the media in a given room."""
 
-    PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$")
+    PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        is_admin = await self.auth.is_server_admin(requester.user)
-        if not is_admin:
-            raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
+        await assert_requester_is_admin(self.auth, request)
 
         local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
 
@@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
 class DeleteMediaByID(RestServlet):
     """Delete local media by a given ID. Removes it from this server."""
 
-    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
+    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet):
     timestamp and size.
     """
 
-    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$")
+    PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet):
         media that exist given for this user
     """
 
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$")
 
     def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
@@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet):
                 request,
                 "order_by",
                 default=MediaSortOrder.CREATED_TS.value,
-                allowed_values=(
-                    MediaSortOrder.MEDIA_ID.value,
-                    MediaSortOrder.UPLOAD_NAME.value,
-                    MediaSortOrder.CREATED_TS.value,
-                    MediaSortOrder.LAST_ACCESS_TS.value,
-                    MediaSortOrder.MEDIA_LENGTH.value,
-                    MediaSortOrder.MEDIA_TYPE.value,
-                    MediaSortOrder.QUARANTINED_BY.value,
-                    MediaSortOrder.SAFE_FROM_QUARANTINE.value,
-                ),
+                allowed_values=[sort_order.value for sort_order in MediaSortOrder],
             )
             direction = parse_string(
                 request, "dir", default="f", allowed_values=("f", "b")
@@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet):
                 request,
                 "order_by",
                 default=MediaSortOrder.CREATED_TS.value,
-                allowed_values=(
-                    MediaSortOrder.MEDIA_ID.value,
-                    MediaSortOrder.UPLOAD_NAME.value,
-                    MediaSortOrder.CREATED_TS.value,
-                    MediaSortOrder.LAST_ACCESS_TS.value,
-                    MediaSortOrder.MEDIA_LENGTH.value,
-                    MediaSortOrder.MEDIA_TYPE.value,
-                    MediaSortOrder.QUARANTINED_BY.value,
-                    MediaSortOrder.SAFE_FROM_QUARANTINE.value,
-                ),
+                allowed_values=[sort_order.value for sort_order in MediaSortOrder],
             )
             direction = parse_string(
                 request, "dir", default="f", allowed_values=("f", "b")
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 891b98c088..04948b6408 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet):
     PATTERNS = admin_patterns("/registration_tokens$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet):
     PATTERNS = admin_patterns("/registration_tokens/new$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet):
     PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.clock = hs.get_clock()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 829e86675a..6030373ebc 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet):
     If 'purge' is true, it will remove all traces of a room from the database.
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
@@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet):
 class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
     """Get the status of the delete room background task."""
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
@@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
 class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
     """Get the status of the delete room background task."""
 
-    PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2")
+    PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
@@ -193,35 +193,17 @@ class ListRoomRestServlet(RestServlet):
         self.admin_handler = hs.get_admin_handler()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         # Extract query parameters
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
-        order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
-        if order_by not in (
-            RoomSortOrder.ALPHABETICAL.value,
-            RoomSortOrder.SIZE.value,
-            RoomSortOrder.NAME.value,
-            RoomSortOrder.CANONICAL_ALIAS.value,
-            RoomSortOrder.JOINED_MEMBERS.value,
-            RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
-            RoomSortOrder.VERSION.value,
-            RoomSortOrder.CREATOR.value,
-            RoomSortOrder.ENCRYPTION.value,
-            RoomSortOrder.FEDERATABLE.value,
-            RoomSortOrder.PUBLIC.value,
-            RoomSortOrder.JOIN_RULES.value,
-            RoomSortOrder.GUEST_ACCESS.value,
-            RoomSortOrder.HISTORY_VISIBILITY.value,
-            RoomSortOrder.STATE_EVENTS.value,
-        ):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Unknown value for order_by: %s" % (order_by,),
-                errcode=Codes.INVALID_PARAM,
-            )
+        order_by = parse_string(
+            request,
+            "order_by",
+            default=RoomSortOrder.NAME.value,
+            allowed_values=[sort_order.value for sort_order in RoomSortOrder],
+        )
 
         search_term = parse_string(request, "search_term", encoding="utf-8")
         if search_term == "":
@@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet):
     TODO: Add on_POST to allow room creation without joining the room
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet):
     Get members list of a room.
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet):
     Get full state within a room.
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
@@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room(room_id)
         if not ret:
@@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet):
 
 class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
 
-    PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
+    PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
         self.state_handler = hs.get_state_handler()
+        self.is_mine = hs.is_mine
 
     async def on_POST(
         self, request: SynapseRequest, room_identifier: str
@@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
         assert_params_in_dict(content, ["user_id"])
         target_user = UserID.from_string(content["user_id"])
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "This endpoint can only be used with local users",
@@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         }
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$")
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.event_creation_handler = hs.get_event_creation_handler()
@@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
         GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+    PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$")
 
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
     async def on_DELETE(
         self, request: SynapseRequest, room_identifier: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         room_id, _ = await self.resolve_room_id(room_identifier)
 
@@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
     async def on_GET(
         self, request: SynapseRequest, room_identifier: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
+        await assert_requester_is_admin(self.auth, request)
 
         room_id, _ = await self.resolve_room_id(room_identifier)
 
@@ -771,13 +745,19 @@ class RoomEventContextServlet(RestServlet):
 
         time_now = self.clock.time_msec()
         results["events_before"] = await self._event_serializer.serialize_events(
-            results["events_before"], time_now
+            results["events_before"],
+            time_now,
+            bundle_aggregations=True,
         )
         results["event"] = await self._event_serializer.serialize_event(
-            results["event"], time_now
+            results["event"],
+            time_now,
+            bundle_aggregations=True,
         )
         results["events_after"] = await self._event_serializer.serialize_events(
-            results["events_after"], time_now
+            results["events_after"],
+            time_now,
+            bundle_aggregations=True,
         )
         results["state"] = await self._event_serializer.serialize_events(
             results["state"], time_now
@@ -793,7 +773,7 @@ class BlockRoomRestServlet(RestServlet):
     On GET: Get blocking status of room and user who has blocked this room.
     """
 
-    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
+    PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$")
 
     def __init__(self, hs: "HomeServer"):
         self._auth = hs.get_auth()
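
The admin room-context endpoint now passes bundle_aggregations=True when serializing the events before, at, and after the requested one, so bundled relations (edits, reactions and the like) are included as on the client-facing endpoint. A minimal sketch of the call-site shape, using only the keyword shown in the hunk above:

async def serialize_with_aggregations(serializer, events, time_now):
    # `serializer` is the event serializer used above; this simply forwards the
    # same keyword argument for a list of events.
    return await serializer.serialize_events(
        events,
        time_now,
        bundle_aggregations=True,
    )
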
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index b295fb078b..15da9cd881 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet):
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.server_notices_manager = hs.get_server_notices_manager()
         self.admin_handler = hs.get_admin_handler()
         self.txns = HttpTransactionCache(hs)
+        self.is_mine = hs.is_mine
 
     def register(self, json_resource: HttpServer) -> None:
         PATTERN = "/send_server_notice"
@@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet):
             )
 
         target_user = UserID.from_string(body["user_id"])
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
             )
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index ca41fd45f2..7a6546372e 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet):
     PATTERNS = admin_patterns("/statistics/users/media$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
@@ -45,19 +44,16 @@ class UserMediaStatisticsRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         order_by = parse_string(
-            request, "order_by", default=UserSortOrder.USER_ID.value
+            request,
+            "order_by",
+            default=UserSortOrder.USER_ID.value,
+            allowed_values=(
+                UserSortOrder.MEDIA_LENGTH.value,
+                UserSortOrder.MEDIA_COUNT.value,
+                UserSortOrder.USER_ID.value,
+                UserSortOrder.DISPLAYNAME.value,
+            ),
         )
-        if order_by not in (
-            UserSortOrder.MEDIA_LENGTH.value,
-            UserSortOrder.MEDIA_COUNT.value,
-            UserSortOrder.USER_ID.value,
-            UserSortOrder.DISPLAYNAME.value,
-        ):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Unknown value for order_by: %s" % (order_by,),
-                errcode=Codes.INVALID_PARAM,
-            )
 
         start = parse_integer(request, "from", default=0)
         if start < 0:
diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py
index 2bf1472967..5353dc3682 100644
--- a/synapse/rest/admin/username_available.py
+++ b/synapse/rest/admin/username_available.py
@@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet):
         }
     """
 
-    PATTERNS = admin_patterns("/username_available")
+    PATTERNS = admin_patterns("/username_available$")
 
     def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 2a60b602b1..78e795c347 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet):
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
@@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet):
 
 
 class UserRestServletV2(RestServlet):
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2")
 
     """Get request to list user details.
     This needs user to have administrator access in Synapse.
@@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet):
              nonce to the time it was generated, in int seconds.
     """
 
-    PATTERNS = admin_patterns("/register")
+    PATTERNS = admin_patterns("/register$")
     NONCE_TIMEOUT = 60
 
     def __init__(self, hs: "HomeServer"):
@@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet):
     ]
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
@@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet):
         if target_user != auth_user:
             await assert_user_is_admin(self.auth, auth_user)
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
 
         ret = await self.admin_handler.get_whois(target_user)
@@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet):
 
 
 class DeactivateAccountRestServlet(RestServlet):
-    PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+    PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
@@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet):
     PATTERNS = admin_patterns("/account_validity/validity$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.account_activity_handler = hs.get_account_validity_handler()
         self.auth = hs.get_auth()
 
@@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet):
             200 OK with empty object if success otherwise an error.
     """
 
-    PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
+    PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
-        self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self._set_password_handler = hs.get_set_password_handler()
@@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet):
             200 OK with json object {list[dict[str, Any]], count} or empty object.
     """
 
-    PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
+    PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, target_user_id: str
@@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet):
         # if not is_admin and target_user != auth_user:
         #     raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
 
         term = parse_string(request, "term", required=True)
@@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.is_mine = hs.is_mine
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
@@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet):
 
         target_user = UserID.from_string(user_id)
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Only local users can be admins of this homeserver",
@@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet):
 
         assert_params_in_dict(body, ["admin"])
 
-        if not self.hs.is_mine(target_user):
+        if not self.is_mine(target_user):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Only local users can be admins of this homeserver",
@@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet):
     Get room list of an user.
     """
 
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$")
 
     def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
@@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
+        self.is_mine_id = hs.is_mine_id
 
     async def on_POST(
         self, request: SynapseRequest, user_id: str
@@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet):
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
             )
@@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet):
         {}
     """
 
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.is_mine_id = hs.is_mine_id
 
     async def on_POST(
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
             )
@@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
             )
@@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet):
         }
     """
 
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$")
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self.is_mine_id = hs.is_mine_id
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
 
         if not await self.store.get_user_by_id(user_id):
@@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
             )
@@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if not self.hs.is_mine_id(user_id):
+        if not self.is_mine_id(user_id):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
             )
@@ -1124,3 +1121,33 @@ class RateLimitRestServlet(RestServlet):
         await self.store.delete_ratelimit_for_user(user_id)
 
         return HTTPStatus.OK, {}
+
+
+class AccountDataRestServlet(RestServlet):
+    """Retrieve the given user's account data"""
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/accountdata")
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth = hs.get_auth()
+        self._store = hs.get_datastore()
+        self._is_mine_id = hs.is_mine_id
+
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        await assert_requester_is_admin(self._auth, request)
+
+        if not self._is_mine_id(user_id):
+            raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
+
+        if not await self._store.get_user_by_id(user_id):
+            raise NotFoundError("User not found")
+
+        global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
+        return HTTPStatus.OK, {
+            "account_data": {
+                "global": global_data,
+                "rooms": by_room_data,
+            },
+        }
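
For illustration only (not part of this diff): an admin client could exercise the new account-data endpoint roughly as follows. The /_synapse/admin/v1 prefix, server URL, token and user ID below are assumptions, mirroring how the other admin servlets in this file are addressed.

# Hypothetical usage sketch of the new AccountDataRestServlet endpoint.
import requests

SERVER = "https://synapse.example.com"        # assumed homeserver base URL
ADMIN_TOKEN = "syt_admin_access_token"        # assumed admin access token
USER_ID = "@alice:example.com"                # must be a local user

resp = requests.get(
    f"{SERVER}/_synapse/admin/v1/users/{USER_ID}/accountdata",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()
# Expected response shape: {"account_data": {"global": {...}, "rooms": {...}}}
print(resp.json()["account_data"]["global"])
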
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8566dc5cb5..ad6fd6492b 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -17,6 +17,7 @@ import logging
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api import errors
+from synapse.api.errors import NotFoundError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -24,10 +25,9 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
 )
 from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.types import JsonDict
 
-from ._base import client_patterns, interactive_auth_handler
-
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
@@ -116,6 +116,8 @@ class DeviceRestServlet(RestServlet):
         device = await self.device_handler.get_device(
             requester.user.to_string(), device_id
         )
+        if device is None:
+            raise NotFoundError("No device found")
         return 200, device
 
     @interactive_auth_handler
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index d1d8a984c6..acd0c9e135 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -15,6 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
+from synapse.api.constants import ReceiptTypes
 from synapse.events.utils import format_event_for_client_v2_without_room_id
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -54,10 +55,10 @@ class NotificationsServlet(RestServlet):
         )
 
         receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
-            user_id, "m.read"
+            user_id, ReceiptTypes.READ
         )
 
-        notif_event_ids = [pa["event_id"] for pa in push_actions]
+        notif_event_ids = [pa.event_id for pa in push_actions]
         notif_events = await self.store.get_events(notif_event_ids)
 
         returned_push_actions = []
@@ -66,30 +67,30 @@ class NotificationsServlet(RestServlet):
 
         for pa in push_actions:
             returned_pa = {
-                "room_id": pa["room_id"],
-                "profile_tag": pa["profile_tag"],
-                "actions": pa["actions"],
-                "ts": pa["received_ts"],
+                "room_id": pa.room_id,
+                "profile_tag": pa.profile_tag,
+                "actions": pa.actions,
+                "ts": pa.received_ts,
                 "event": (
                     await self._event_serializer.serialize_event(
-                        notif_events[pa["event_id"]],
+                        notif_events[pa.event_id],
                         self.clock.time_msec(),
                         event_format=format_event_for_client_v2_without_room_id,
                     )
                 ),
             }
 
-            if pa["room_id"] not in receipts_by_room:
+            if pa.room_id not in receipts_by_room:
                 returned_pa["read"] = False
             else:
-                receipt = receipts_by_room[pa["room_id"]]
+                receipt = receipts_by_room[pa.room_id]
 
                 returned_pa["read"] = (
                     receipt["topological_ordering"],
                     receipt["stream_ordering"],
-                ) >= (pa["topological_ordering"], pa["stream_ordering"])
+                ) >= (pa.topological_ordering, pa.stream_ordering)
             returned_push_actions.append(returned_pa)
-            next_token = str(pa["stream_ordering"])
+            next_token = str(pa.stream_ordering)
 
         return 200, {"notifications": returned_push_actions, "next_token": next_token}
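
The hunk above switches from dict-style access (pa["event_id"]) to attribute access (pa.event_id), which implies the storage layer now hands back typed push-action objects. The attrs class below is only a stand-in sketching the fields this servlet relies on; the real class is defined in the storage module, not reproduced here.

# Illustrative stand-in only: a typed push action with the attributes used above.
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class ExamplePushAction:
    event_id: str
    room_id: str
    stream_ordering: int
    topological_ordering: int
    actions: list
    received_ts: int
    profile_tag: str

pa = ExamplePushAction(
    event_id="$abc",
    room_id="!room:example.com",
    stream_ordering=42,
    topological_ordering=7,
    actions=["notify"],
    received_ts=1_600_000_000_000,
    profile_tag="",
)
assert pa.event_id == "$abc"  # attribute access instead of pa["event_id"]
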
 
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 43c04fac6f..f51be511d1 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
         await self.presence_handler.bump_presence_active_time(requester.user)
 
         body = parse_json_object_from_request(request)
-        read_event_id = body.get("m.read", None)
+        read_event_id = body.get(ReceiptTypes.READ, None)
         hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
 
         if not isinstance(hidden, bool):
@@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet):
         if read_event_id:
             await self.receipts_handler.received_client_receipt(
                 room_id,
-                "m.read",
+                ReceiptTypes.READ,
                 user_id=requester.user.to_string(),
                 event_id=read_event_id,
                 hidden=hidden,
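
Both this file and receipts.py below replace the literal "m.read" with ReceiptTypes.READ. As a rough sketch of what that constant looks like (the canonical definition lives in synapse/api/constants.py, which is touched elsewhere in this change set):

# Approximate shape of the constant; see synapse/api/constants.py for the real one.
class ReceiptTypes:
    READ = "m.read"

assert ReceiptTypes.READ == "m.read"
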
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 2b25b9aad6..b24ad2d1be 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -16,7 +16,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http import get_request_user_agent
 from synapse.http.server import HttpServer
@@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        if receipt_type != "m.read":
+        if receipt_type != ReceiptTypes.READ:
             raise SynapseError(400, "Receipt type must be 'm.read'")
 
         # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index fc4e6921c5..5815650ee6 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet):
 
             pagination_chunk = await self.store.get_relations_for_event(
                 event_id=parent_id,
+                room_id=room_id,
                 relation_type=relation_type,
                 event_type=event_type,
                 limit=limit,
@@ -231,7 +232,9 @@ class RelationPaginationServlet(RestServlet):
         )
         # The relations returned for the requested event do include their
         # bundled aggregations.
-        serialized_events = await self._event_serializer.serialize_events(events, now)
+        serialized_events = await self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=True
+        )
 
         return_value = pagination_chunk.to_dict()
         return_value["chunk"] = serialized_events
@@ -317,6 +320,7 @@ class RelationAggregationPaginationServlet(RestServlet):
 
             pagination_chunk = await self.store.get_aggregation_groups_for_event(
                 event_id=parent_id,
+                room_id=room_id,
                 event_type=event_type,
                 limit=limit,
                 from_token=from_token,
@@ -383,7 +387,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
 
         # This checks that a) the event exists and b) the user is allowed to
         # view it.
-        await self.event_handler.get_event(requester.user, room_id, parent_id)
+        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+        if event is None:
+            raise SynapseError(404, "Unknown parent event.")
 
         if relation_type != RelationTypes.ANNOTATION:
             raise SynapseError(400, "Relation type must be 'annotation'")
@@ -402,6 +408,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
 
         result = await self.store.get_relations_for_event(
             event_id=parent_id,
+            room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
             aggregation_key=key,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index f48e2e6ca2..40330749e5 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -187,7 +187,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
         state_key: str,
         txn_id: Optional[str] = None,
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         if txn_id:
             set_tag("txn_id", txn_id)
@@ -662,7 +662,9 @@ class RoomEventServlet(RestServlet):
 
         time_now = self.clock.time_msec()
         if event:
-            event_dict = await self._event_serializer.serialize_event(event, time_now)
+            event_dict = await self._event_serializer.serialize_event(
+                event, time_now, bundle_aggregations=True
+            )
             return 200, event_dict
 
         raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@@ -707,13 +709,13 @@ class RoomEventContextServlet(RestServlet):
 
         time_now = self.clock.time_msec()
         results["events_before"] = await self._event_serializer.serialize_events(
-            results["events_before"], time_now
+            results["events_before"], time_now, bundle_aggregations=True
         )
         results["event"] = await self._event_serializer.serialize_event(
-            results["event"], time_now
+            results["event"], time_now, bundle_aggregations=True
         )
         results["events_after"] = await self._event_serializer.serialize_events(
-            results["events_after"], time_now
+            results["events_after"], time_now, bundle_aggregations=True
         )
         results["state"] = await self._event_serializer.serialize_events(
             results["state"], time_now
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e556ff93e6..e99a943d0d 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -48,6 +48,7 @@ from synapse.handlers.sync import (
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
+from synapse.logging.opentracing import trace
 from synapse.types import JsonDict, StreamToken
 from synapse.util import json_decoder
 
@@ -222,6 +223,7 @@ class SyncRestServlet(RestServlet):
         logger.debug("Event formatting complete")
         return 200, response_content
 
+    @trace(opname="sync.encode_response")
     async def encode_response(
         self,
         time_now: int,
@@ -293,6 +295,9 @@ class SyncRestServlet(RestServlet):
         response[
             "org.matrix.msc2732.device_unused_fallback_key_types"
         ] = sync_result.device_unused_fallback_key_types
+        response[
+            "device_unused_fallback_key_types"
+        ] = sync_result.device_unused_fallback_key_types
 
         if joined:
             response["rooms"][Membership.JOIN] = joined
@@ -329,6 +334,7 @@ class SyncRestServlet(RestServlet):
             ]
         }
 
+    @trace(opname="sync.encode_joined")
     async def encode_joined(
         self,
         rooms: List[JoinedSyncResult],
@@ -365,6 +371,7 @@ class SyncRestServlet(RestServlet):
 
         return joined
 
+    @trace(opname="sync.encode_invited")
     async def encode_invited(
         self,
         rooms: List[InvitedSyncResult],
@@ -403,6 +410,7 @@ class SyncRestServlet(RestServlet):
 
         return invited
 
+    @trace(opname="sync.encode_knocked")
     async def encode_knocked(
         self,
         rooms: List[KnockedSyncResult],
@@ -457,6 +465,7 @@ class SyncRestServlet(RestServlet):
 
         return knocked
 
+    @trace(opname="sync.encode_archived")
     async def encode_archived(
         self,
         rooms: List[ArchivedSyncResult],
@@ -528,6 +537,8 @@ class SyncRestServlet(RestServlet):
                 # overhead for initialsyncs. We need to figure out a way that the
                 # bundling can be done *before* the events are stored in the
                 # SyncResponseCache so that this part can be synchronous.
+                #
+                # Once that is possible, re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations.
                 bundle_aggregations=False,
                 token_id=token_id,
                 event_format=event_formatter,
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 8d888f4565..2290c57c12 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -93,6 +93,10 @@ class VersionsRestServlet(RestServlet):
                     "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
                     # Supports receiving hidden read receipts as per MSC2285
                     "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
+                    # Adds support for importing historical messages as per MSC2716
+                    "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
+                    # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
+                    "org.matrix.msc3030": self.config.experimental.msc3030_enabled,
                 },
             },
         )
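
For reference, with both experimental options switched on in the homeserver configuration, clients querying /_matrix/client/versions would see entries along these lines in unstable_features (the values here are assumed for illustration):

# Assumed fragment of the "unstable_features" map when msc2716_enabled and
# msc3030_enabled are both true.
unstable_features = {
    "org.matrix.msc2716": True,   # importing historical messages
    "org.matrix.msc3030": True,   # /timestamp_to_event jump-to-date endpoints
}
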
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 12b3ae120c..b9bfbea21b 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 from canonicaljson import encode_canonical_json
 from signedjson.sign import sign_json
@@ -99,7 +99,7 @@ class LocalKey(Resource):
             json_object = sign_json(json_object, self.config.server.server_name, key)
         return json_object
 
-    def render_GET(self, request: Request) -> int:
+    def render_GET(self, request: Request) -> Optional[int]:
         time_now = self.clock.time_msec()
         # Update the expiry time if less than half the interval remains.
         if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 244ba261bb..71b9a34b14 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -739,14 +739,21 @@ class MediaRepository:
         # We deduplicate the thumbnail sizes by ignoring the cropped versions if
         # they have the same dimensions of a scaled one.
         thumbnails: Dict[Tuple[int, int, str], str] = {}
-        for r_width, r_height, r_method, r_type in requirements:
-            if r_method == "crop":
-                thumbnails.setdefault((r_width, r_height, r_type), r_method)
-            elif r_method == "scale":
-                t_width, t_height = thumbnailer.aspect(r_width, r_height)
+        for requirement in requirements:
+            if requirement.method == "crop":
+                thumbnails.setdefault(
+                    (requirement.width, requirement.height, requirement.media_type),
+                    requirement.method,
+                )
+            elif requirement.method == "scale":
+                t_width, t_height = thumbnailer.aspect(
+                    requirement.width, requirement.height
+                )
                 t_width = min(m_width, t_width)
                 t_height = min(m_height, t_height)
-                thumbnails[(t_width, t_height, r_type)] = r_method
+                thumbnails[
+                    (t_width, t_height, requirement.media_type)
+                ] = requirement.method
 
         # Now we generate the thumbnails for each dimension, store it
         for (t_width, t_height, t_type), t_method in thumbnails.items():
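
The loop above now reads thumbnail requirements through attributes (requirement.width, requirement.method, ...) instead of unpacking tuples. A minimal stand-in with only the fields the loop touches might look like this; the real class lives in the media repository configuration and is not reproduced here.

# Illustrative stand-in only: a structured thumbnail requirement.
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class ExampleThumbnailRequirement:
    width: int
    height: int
    method: str       # "crop" or "scale"
    media_type: str   # e.g. "image/jpeg"

requirements = [
    ExampleThumbnailRequirement(32, 32, "crop", "image/jpeg"),
    ExampleThumbnailRequirement(640, 480, "scale", "image/jpeg"),
]
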
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2a59552c20..cce1527ed9 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, List, Optional
 
 import attr
 
+from synapse.rest.media.v1.preview_html import parse_html_description
 from synapse.types import JsonDict
 from synapse.util import json_decoder
 
@@ -245,8 +246,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) ->
     if video_urls:
         open_graph_response["og:video"] = video_urls[0]
 
-    from synapse.rest.media.v1.preview_url_resource import _calc_description
-
-    description = _calc_description(tree)
+    description = parse_html_description(tree)
     if description:
         open_graph_response["og:description"] = description
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
new file mode 100644
index 0000000000..30b067dd42
--- /dev/null
+++ b/synapse/rest/media/v1/preview_html.py
@@ -0,0 +1,397 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import codecs
+import itertools
+import logging
+import re
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
+from urllib import parse as urlparse
+
+if TYPE_CHECKING:
+    from lxml import etree
+
+logger = logging.getLogger(__name__)
+
+_charset_match = re.compile(
+    br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
+)
+_xml_encoding_match = re.compile(
+    br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
+)
+_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
+
+
+def _normalise_encoding(encoding: str) -> Optional[str]:
+    """Use the Python codec's name as the normalised entry."""
+    try:
+        return codecs.lookup(encoding).name
+    except LookupError:
+        return None
+
+
+def _get_html_media_encodings(
+    body: bytes, content_type: Optional[str]
+) -> Iterable[str]:
+    """
+    Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
+
+    The precedence used for finding a character encoding is:
+
+    1. <meta> tag with a charset declared.
+    2. The XML document's character encoding attribute.
+    3. The Content-Type header.
+    4. Fallback to utf-8.
+    5. Fallback to windows-1252.
+
+    This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
+
+    Args:
+        body: The HTML document, as bytes.
+        content_type: The Content-Type header.
+
+    Returns:
+        The character encodings to try for decoding the body, in order of preference.
+    """
+    # There's no point in returning an encoding more than once.
+    attempted_encodings: Set[str] = set()
+
+    # Limit searches to the first 1kb, since it ought to be at the top.
+    body_start = body[:1024]
+
+    # Check if it has an encoding set in a meta tag.
+    match = _charset_match.search(body_start)
+    if match:
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding:
+            attempted_encodings.add(encoding)
+            yield encoding
+
+    # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+    # Check if it has an XML document with an encoding.
+    match = _xml_encoding_match.match(body_start)
+    if match:
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding and encoding not in attempted_encodings:
+            attempted_encodings.add(encoding)
+            yield encoding
+
+    # Check the HTTP Content-Type header for a character set.
+    if content_type:
+        content_match = _content_type_match.match(content_type)
+        if content_match:
+            encoding = _normalise_encoding(content_match.group(1))
+            if encoding and encoding not in attempted_encodings:
+                attempted_encodings.add(encoding)
+                yield encoding
+
+    # Finally, fallback to UTF-8, then windows-1252.
+    for fallback in ("utf-8", "cp1252"):
+        if fallback not in attempted_encodings:
+            yield fallback
+
+
+def decode_body(
+    body: bytes, uri: str, content_type: Optional[str] = None
+) -> Optional["etree.Element"]:
+    """
+    This uses lxml to parse the HTML document.
+
+    Args:
+        body: The HTML document, as bytes.
+        uri: The URI used to download the body.
+        content_type: The Content-Type header.
+
+    Returns:
+        The parsed HTML body, or None if an error occurred during processing.
+    """
+    # If there's no body, nothing useful is going to be found.
+    if not body:
+        return None
+
+    # The idea here is that multiple encodings are tried until one works.
+    # Unfortunately the result is never used and then LXML will decode the string
+    # again with the found encoding.
+    for encoding in _get_html_media_encodings(body, content_type):
+        try:
+            body.decode(encoding)
+        except Exception:
+            pass
+        else:
+            break
+    else:
+        logger.warning("Unable to decode HTML body for %s", uri)
+        return None
+
+    from lxml import etree
+
+    # Create an HTML parser.
+    parser = etree.HTMLParser(recover=True, encoding=encoding)
+
+    # Attempt to parse the body. Returns None if the body was successfully
+    # parsed, but no tree was found.
+    return etree.fromstring(body, parser)
+
+
+def parse_html_to_open_graph(
+    tree: "etree.Element", media_uri: str
+) -> Dict[str, Optional[str]]:
+    """
+    Parse the HTML document into an Open Graph response.
+
+    This uses lxml to search the HTML document for Open Graph data (or
+    synthesizes it from the document).
+
+    Args:
+        tree: The parsed HTML document.
+        media_uri: The URI used to download the body.
+
+    Returns:
+        The Open Graph response as a dictionary.
+    """
+
+    # if we see any image URLs in the OG response, then spider them
+    # (although the client could choose to do this by asking for previews of those
+    # URLs to avoid DoSing the server)
+
+    # "og:type"         : "video",
+    # "og:url"          : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
+    # "og:site_name"    : "YouTube",
+    # "og:video:type"   : "application/x-shockwave-flash",
+    # "og:description"  : "Fun stuff happening here",
+    # "og:title"        : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
+    # "og:image"        : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
+    # "og:video:url"    : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
+    # "og:video:width"  : "1280"
+    # "og:video:height" : "720",
+    # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
+
+    og: Dict[str, Optional[str]] = {}
+    for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
+        if "content" in tag.attrib:
+            # if we've got more than 50 tags, someone is taking the piss
+            if len(og) >= 50:
+                logger.warning("Skipping OG for page with too many 'og:' tags")
+                return {}
+            og[tag.attrib["property"]] = tag.attrib["content"]
+
+    # TODO: grab article: meta tags too, e.g.:
+
+    # "article:publisher" : "https://www.facebook.com/thethudonline" />
+    # "article:author" content="https://www.facebook.com/thethudonline" />
+    # "article:tag" content="baby" />
+    # "article:section" content="Breaking News" />
+    # "article:published_time" content="2016-03-31T19:58:24+00:00" />
+    # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
+
+    if "og:title" not in og:
+        # do some basic spidering of the HTML
+        title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
+        if title and title[0].text is not None:
+            og["og:title"] = title[0].text.strip()
+        else:
+            og["og:title"] = None
+
+    if "og:image" not in og:
+        # TODO: extract a favicon failing all else
+        meta_image = tree.xpath(
+            "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
+        )
+        if meta_image:
+            og["og:image"] = rebase_url(meta_image[0], media_uri)
+        else:
+            # TODO: consider inlined CSS styles as well as width & height attribs
+            images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
+            images = sorted(
+                images,
+                key=lambda i: (
+                    -1 * float(i.attrib["width"]) * float(i.attrib["height"])
+                ),
+            )
+            if not images:
+                images = tree.xpath("//img[@src]")
+            if images:
+                og["og:image"] = images[0].attrib["src"]
+
+    if "og:description" not in og:
+        meta_description = tree.xpath(
+            "//*/meta"
+            "[translate(@name, 'DESCRIPTION', 'description')='description']"
+            "/@content"
+        )
+        if meta_description:
+            og["og:description"] = meta_description[0]
+        else:
+            og["og:description"] = parse_html_description(tree)
+    elif og["og:description"]:
+        # This must be a non-empty string at this point.
+        assert isinstance(og["og:description"], str)
+        og["og:description"] = summarize_paragraphs([og["og:description"]])
+
+    # TODO: delete the url downloads to stop diskfilling,
+    # as we only ever cared about its OG
+    return og
+
+
+def parse_html_description(tree: "etree.Element") -> Optional[str]:
+    """
+    Calculate a text description based on an HTML document.
+
+    Grabs any text nodes which are inside the <body/> tag, unless they are within
+    an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
+    if they are within a <script/> or <style/> tag.
+
+    This is a very very very coarse approximation to a plain text render of the page.
+
+    Args:
+        tree: The parsed HTML document.
+
+    Returns:
+        The plain text description, or None if one cannot be generated.
+    """
+    # We don't just use XPATH here as that is slow on some machines.
+
+    from lxml import etree
+
+    TAGS_TO_REMOVE = (
+        "header",
+        "nav",
+        "aside",
+        "footer",
+        "script",
+        "noscript",
+        "style",
+        etree.Comment,
+    )
+
+    # Split all the text nodes into paragraphs (by splitting on new
+    # lines)
+    text_nodes = (
+        re.sub(r"\s+", "\n", el).strip()
+        for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
+    )
+    return summarize_paragraphs(text_nodes)
+
+
+def _iterate_over_text(
+    tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
+    """Iterate over the tree returning text nodes in a depth first fashion,
+    skipping text nodes inside certain tags.
+    """
+    # This is basically a stack that we extend using itertools.chain.
+    # This will either consist of an element to iterate over *or* a string
+    # to be returned.
+    elements = iter([tree])
+    while True:
+        el = next(elements, None)
+        if el is None:
+            return
+
+        if isinstance(el, str):
+            yield el
+        elif el.tag not in tags_to_ignore:
+            # el.text is the text before the first child, so we can immediately
+            # return it if the text exists.
+            if el.text:
+                yield el.text
+
+            # We add to the stack all the elements children, interspersed with
+            # each child's tail text (if it exists). The tail text of a node
+            # is text that comes *after* the node, so we always include it even
+            # if we ignore the child node.
+            elements = itertools.chain(
+                itertools.chain.from_iterable(  # Basically a flatmap
+                    [child, child.tail] if child.tail else [child]
+                    for child in el.iterchildren()
+                ),
+                elements,
+            )
+
+
+def rebase_url(url: str, base: str) -> str:
+    base_parts = list(urlparse.urlparse(base))
+    url_parts = list(urlparse.urlparse(url))
+    if not url_parts[0]:  # fix up schema
+        url_parts[0] = base_parts[0] or "http"
+    if not url_parts[1]:  # fix up hostname
+        url_parts[1] = base_parts[1]
+        if not url_parts[2].startswith("/"):
+            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+    return urlparse.urlunparse(url_parts)
+
+
+def summarize_paragraphs(
+    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
+    """
+    Try to get a summary respecting first paragraph and then word boundaries.
+
+    Args:
+        text_nodes: The paragraphs to summarize.
+        min_size: The minimum number of characters to include.
+        max_size: The maximum number of characters to include.
+
+    Returns:
+        A summary of the text nodes, or None if that was not possible.
+    """
+
+    # TODO: Respect sentences?
+
+    description = ""
+
+    # Keep adding paragraphs until we get to the MIN_SIZE.
+    for text_node in text_nodes:
+        if len(description) < min_size:
+            text_node = re.sub(r"[\t \r\n]+", " ", text_node)
+            description += text_node + "\n\n"
+        else:
+            break
+
+    description = description.strip()
+    description = re.sub(r"[\t ]+", " ", description)
+    description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
+
+    # If the concatenation of paragraphs to get above MIN_SIZE
+    # took us over MAX_SIZE, then we need to truncate mid paragraph
+    if len(description) > max_size:
+        new_desc = ""
+
+        # This splits the paragraph into words, but keeping the
+        # (preceding) whitespace intact so we can easily concat
+        # words back together.
+        for match in re.finditer(r"\s*\S+", description):
+            word = match.group()
+
+            # Keep adding words while the total length is less than
+            # MAX_SIZE.
+            if len(word) + len(new_desc) < max_size:
+                new_desc += word
+            else:
+                # At this point the next word *will* take us over
+                # MAX_SIZE, but we also want to ensure that its not
+                # a huge word. If it is add it anyway and we'll
+                # truncate later.
+                if len(new_desc) < min_size:
+                    new_desc += word
+                break
+
+        # Double check that we're not over the limit
+        if len(new_desc) > max_size:
+            new_desc = new_desc[:max_size]
+
+        # We always add an ellipsis because at the very least
+        # we chopped mid paragraph.
+        description = new_desc.strip() + "…"
+    return description if description else None
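
The new module can be exercised on its own. A rough usage sketch follows, assuming lxml is installed and the module is importable exactly as added above; the HTML snippet and URL are made up.

# Rough usage sketch of the helpers in synapse/rest/media/v1/preview_html.py.
from synapse.rest.media.v1.preview_html import (
    decode_body,
    parse_html_description,
    parse_html_to_open_graph,
)

html = (
    b"<html><head><title>Example page</title>"
    b'<meta property="og:description" content="A short description."></head>'
    b"<body><p>Some paragraph text that would be summarised.</p></body></html>"
)

tree = decode_body(html, "https://example.com/page.html", "text/html; charset=utf-8")
if tree is not None:
    og = parse_html_to_open_graph(tree, "https://example.com/page.html")
    print(og.get("og:title"), og.get("og:description"))
    print(parse_html_description(tree))
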
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 054f3c296d..a3829d943b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,18 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import codecs
 import datetime
 import errno
 import fnmatch
-import itertools
 import logging
 import os
 import re
 import shutil
 import sys
 import traceback
-from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple
 from urllib import parse as urlparse
 
 import attr
@@ -45,6 +43,11 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.rest.media.v1.preview_html import (
+    decode_body,
+    parse_html_to_open_graph,
+    rebase_url,
+)
 from synapse.types import JsonDict, UserID
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
@@ -54,21 +57,11 @@ from synapse.util.stringutils import random_string
 from ._base import FileInfo
 
 if TYPE_CHECKING:
-    from lxml import etree
-
     from synapse.rest.media.v1.media_repository import MediaRepository
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
-_charset_match = re.compile(
-    br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
-)
-_xml_encoding_match = re.compile(
-    br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
-)
-_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
-
 OG_TAG_NAME_MAXLEN = 50
 OG_TAG_VALUE_MAXLEN = 1000
 
@@ -311,7 +304,7 @@ class PreviewUrlResource(DirectServeJsonResource):
                 # If there was no oEmbed URL (or oEmbed parsing failed), attempt
                 # to generate the Open Graph information from the HTML.
                 if not oembed_url or not og:
-                    og = _calc_og(tree, media_info.uri)
+                    og = parse_html_to_open_graph(tree, media_info.uri)
 
                 await self._precache_image_url(user, media_info, og)
             else:
@@ -468,7 +461,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         # request itself and benefit from the same caching etc.  But for now we
         # just rely on the caching on the master request to speed things up.
         image_info = await self._download_url(
-            _rebase_url(og["og:image"], media_info.uri), user
+            rebase_url(og["og:image"], media_info.uri), user
         )
 
         if _is_media(image_info.media_type):
@@ -632,301 +625,6 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
-def _normalise_encoding(encoding: str) -> Optional[str]:
-    """Use the Python codec's name as the normalised entry."""
-    try:
-        return codecs.lookup(encoding).name
-    except LookupError:
-        return None
-
-
-def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]:
-    """
-    Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
-
-    The precedence used for finding a character encoding is:
-
-    1. <meta> tag with a charset declared.
-    2. The XML document's character encoding attribute.
-    3. The Content-Type header.
-    4. Fallback to utf-8.
-    5. Fallback to windows-1252.
-
-    This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
-
-    Args:
-        body: The HTML document, as bytes.
-        content_type: The Content-Type header.
-
-    Returns:
-        The character encoding of the body, as a string.
-    """
-    # There's no point in returning an encoding more than once.
-    attempted_encodings: Set[str] = set()
-
-    # Limit searches to the first 1kb, since it ought to be at the top.
-    body_start = body[:1024]
-
-    # Check if it has an encoding set in a meta tag.
-    match = _charset_match.search(body_start)
-    if match:
-        encoding = _normalise_encoding(match.group(1).decode("ascii"))
-        if encoding:
-            attempted_encodings.add(encoding)
-            yield encoding
-
-    # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
-
-    # Check if it has an XML document with an encoding.
-    match = _xml_encoding_match.match(body_start)
-    if match:
-        encoding = _normalise_encoding(match.group(1).decode("ascii"))
-        if encoding and encoding not in attempted_encodings:
-            attempted_encodings.add(encoding)
-            yield encoding
-
-    # Check the HTTP Content-Type header for a character set.
-    if content_type:
-        content_match = _content_type_match.match(content_type)
-        if content_match:
-            encoding = _normalise_encoding(content_match.group(1))
-            if encoding and encoding not in attempted_encodings:
-                attempted_encodings.add(encoding)
-                yield encoding
-
-    # Finally, fallback to UTF-8, then windows-1252.
-    for fallback in ("utf-8", "cp1252"):
-        if fallback not in attempted_encodings:
-            yield fallback
-
-
-def decode_body(
-    body: bytes, uri: str, content_type: Optional[str] = None
-) -> Optional["etree.Element"]:
-    """
-    This uses lxml to parse the HTML document.
-
-    Args:
-        body: The HTML document, as bytes.
-        uri: The URI used to download the body.
-        content_type: The Content-Type header.
-
-    Returns:
-        The parsed HTML body, or None if an error occurred during processed.
-    """
-    # If there's no body, nothing useful is going to be found.
-    if not body:
-        return None
-
-    # The idea here is that multiple encodings are tried until one works.
-    # Unfortunately the result is never used and then LXML will decode the string
-    # again with the found encoding.
-    for encoding in get_html_media_encodings(body, content_type):
-        try:
-            body.decode(encoding)
-        except Exception:
-            pass
-        else:
-            break
-    else:
-        logger.warning("Unable to decode HTML body for %s", uri)
-        return None
-
-    from lxml import etree
-
-    # Create an HTML parser.
-    parser = etree.HTMLParser(recover=True, encoding=encoding)
-
-    # Attempt to parse the body. Returns None if the body was successfully
-    # parsed, but no tree was found.
-    return etree.fromstring(body, parser)
-
-
-def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
-    """
-    Calculate metadata for an HTML document.
-
-    This uses lxml to search the HTML document for Open Graph data.
-
-    Args:
-        tree: The parsed HTML document.
-        media_url: The URI used to download the body.
-
-    Returns:
-        The Open Graph response as a dictionary.
-    """
-
-    # if we see any image URLs in the OG response, then spider them
-    # (although the client could choose to do this by asking for previews of those
-    # URLs to avoid DoSing the server)
-
-    # "og:type"         : "video",
-    # "og:url"          : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
-    # "og:site_name"    : "YouTube",
-    # "og:video:type"   : "application/x-shockwave-flash",
-    # "og:description"  : "Fun stuff happening here",
-    # "og:title"        : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
-    # "og:image"        : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
-    # "og:video:url"    : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
-    # "og:video:width"  : "1280"
-    # "og:video:height" : "720",
-    # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
-
-    og: Dict[str, Optional[str]] = {}
-    for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
-        if "content" in tag.attrib:
-            # if we've got more than 50 tags, someone is taking the piss
-            if len(og) >= 50:
-                logger.warning("Skipping OG for page with too many 'og:' tags")
-                return {}
-            og[tag.attrib["property"]] = tag.attrib["content"]
-
-    # TODO: grab article: meta tags too, e.g.:
-
-    # "article:publisher" : "https://www.facebook.com/thethudonline" />
-    # "article:author" content="https://www.facebook.com/thethudonline" />
-    # "article:tag" content="baby" />
-    # "article:section" content="Breaking News" />
-    # "article:published_time" content="2016-03-31T19:58:24+00:00" />
-    # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
-
-    if "og:title" not in og:
-        # do some basic spidering of the HTML
-        title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
-        if title and title[0].text is not None:
-            og["og:title"] = title[0].text.strip()
-        else:
-            og["og:title"] = None
-
-    if "og:image" not in og:
-        # TODO: extract a favicon failing all else
-        meta_image = tree.xpath(
-            "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
-        )
-        if meta_image:
-            og["og:image"] = _rebase_url(meta_image[0], media_uri)
-        else:
-            # TODO: consider inlined CSS styles as well as width & height attribs
-            images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
-            images = sorted(
-                images,
-                key=lambda i: (
-                    -1 * float(i.attrib["width"]) * float(i.attrib["height"])
-                ),
-            )
-            if not images:
-                images = tree.xpath("//img[@src]")
-            if images:
-                og["og:image"] = images[0].attrib["src"]
-
-    if "og:description" not in og:
-        meta_description = tree.xpath(
-            "//*/meta"
-            "[translate(@name, 'DESCRIPTION', 'description')='description']"
-            "/@content"
-        )
-        if meta_description:
-            og["og:description"] = meta_description[0]
-        else:
-            og["og:description"] = _calc_description(tree)
-    elif og["og:description"]:
-        # This must be a non-empty string at this point.
-        assert isinstance(og["og:description"], str)
-        og["og:description"] = summarize_paragraphs([og["og:description"]])
-
-    # TODO: delete the url downloads to stop diskfilling,
-    # as we only ever cared about its OG
-    return og
-
-
-def _calc_description(tree: "etree.Element") -> Optional[str]:
-    """
-    Calculate a text description based on an HTML document.
-
-    Grabs any text nodes which are inside the <body/> tag, unless they are within
-    an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
-    if they are within a <script/> or <style/> tag.
-
-    This is a very very very coarse approximation to a plain text render of the page.
-
-    Args:
-        tree: The parsed HTML document.
-
-    Returns:
-        The plain text description, or None if one cannot be generated.
-    """
-    # We don't just use XPATH here as that is slow on some machines.
-
-    from lxml import etree
-
-    TAGS_TO_REMOVE = (
-        "header",
-        "nav",
-        "aside",
-        "footer",
-        "script",
-        "noscript",
-        "style",
-        etree.Comment,
-    )
-
-    # Split all the text nodes into paragraphs (by splitting on new
-    # lines)
-    text_nodes = (
-        re.sub(r"\s+", "\n", el).strip()
-        for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
-    )
-    return summarize_paragraphs(text_nodes)
-
-
-def _iterate_over_text(
-    tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
-) -> Generator[str, None, None]:
-    """Iterate over the tree returning text nodes in a depth first fashion,
-    skipping text nodes inside certain tags.
-    """
-    # This is basically a stack that we extend using itertools.chain.
-    # This will either consist of an element to iterate over *or* a string
-    # to be returned.
-    elements = iter([tree])
-    while True:
-        el = next(elements, None)
-        if el is None:
-            return
-
-        if isinstance(el, str):
-            yield el
-        elif el.tag not in tags_to_ignore:
-            # el.text is the text before the first child, so we can immediately
-            # return it if the text exists.
-            if el.text:
-                yield el.text
-
-            # We add to the stack all the elements children, interspersed with
-            # each child's tail text (if it exists). The tail text of a node
-            # is text that comes *after* the node, so we always include it even
-            # if we ignore the child node.
-            elements = itertools.chain(
-                itertools.chain.from_iterable(  # Basically a flatmap
-                    [child, child.tail] if child.tail else [child]
-                    for child in el.iterchildren()
-                ),
-                elements,
-            )
-
-
-def _rebase_url(url: str, base: str) -> str:
-    base_parts = list(urlparse.urlparse(base))
-    url_parts = list(urlparse.urlparse(url))
-    if not url_parts[0]:  # fix up schema
-        url_parts[0] = base_parts[0] or "http"
-    if not url_parts[1]:  # fix up hostname
-        url_parts[1] = base_parts[1]
-        if not url_parts[2].startswith("/"):
-            url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
-    return urlparse.urlunparse(url_parts)
-
-
 def _is_media(content_type: str) -> bool:
     return content_type.lower().startswith("image/")
 
@@ -940,68 +638,3 @@ def _is_html(content_type: str) -> bool:
 
 def _is_json(content_type: str) -> bool:
     return content_type.lower().startswith("application/json")
-
-
-def summarize_paragraphs(
-    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
-) -> Optional[str]:
-    """
-    Try to get a summary respecting first paragraph and then word boundaries.
-
-    Args:
-        text_nodes: The paragraphs to summarize.
-        min_size: The minimum number of words to include.
-        max_size: The maximum number of words to include.
-
-    Returns:
-        A summary of the text nodes, or None if that was not possible.
-    """
-
-    # TODO: Respect sentences?
-
-    description = ""
-
-    # Keep adding paragraphs until we get to the MIN_SIZE.
-    for text_node in text_nodes:
-        if len(description) < min_size:
-            text_node = re.sub(r"[\t \r\n]+", " ", text_node)
-            description += text_node + "\n\n"
-        else:
-            break
-
-    description = description.strip()
-    description = re.sub(r"[\t ]+", " ", description)
-    description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
-
-    # If the concatenation of paragraphs to get above MIN_SIZE
-    # took us over MAX_SIZE, then we need to truncate mid paragraph
-    if len(description) > max_size:
-        new_desc = ""
-
-        # This splits the paragraph into words, but keeping the
-        # (preceding) whitespace intact so we can easily concat
-        # words back together.
-        for match in re.finditer(r"\s*\S+", description):
-            word = match.group()
-
-            # Keep adding words while the total length is less than
-            # MAX_SIZE.
-            if len(word) + len(new_desc) < max_size:
-                new_desc += word
-            else:
-                # At this point the next word *will* take us over
-                # MAX_SIZE, but we also want to ensure that its not
-                # a huge word. If it is add it anyway and we'll
-                # truncate later.
-                if len(new_desc) < min_size:
-                    new_desc += word
-                break
-
-        # Double check that we're not over the limit
-        if len(new_desc) > max_size:
-            new_desc = new_desc[:max_size]
-
-        # We always add an ellipsis because at the very least
-        # we chopped mid paragraph.
-        description = new_desc.strip() + "…"
-    return description if description else None
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 446204dbe5..69ac8c3423 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 import logging
-from collections import defaultdict, namedtuple
+from collections import defaultdict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -69,9 +69,6 @@ state_groups_histogram = Histogram(
 )
 
 
-KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-
-
 EVICTION_TIMEOUT_SECONDS = 60 * 60
 
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3056e64ff5..7967011afd 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,10 +17,8 @@ import logging
 from abc import ABCMeta
 from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
 
-from synapse.storage.database import LoggingTransaction  # noqa: F401
-from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
@@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0693d39006..2cacc7dd6c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -13,8 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 import time
+import types
 from collections import defaultdict
 from sys import intern
 from time import monotonic as monotonic_time
@@ -53,6 +55,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -175,7 +178,7 @@ class LoggingDatabaseConnection:
     def rollback(self) -> None:
         self.conn.rollback()
 
-    def __enter__(self) -> "Connection":
+    def __enter__(self) -> "LoggingDatabaseConnection":
         self.conn.__enter__()
         return self
 
@@ -526,6 +529,12 @@ class DatabasePool:
         the function will correctly handle being aborted and retried half way
         through its execution.
 
+        Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
+        since they could be evaluated multiple times (which would produce an empty
+        result on the second or subsequent evaluation). Likewise, the closure of `func`
+        must not reference any generators.  This method attempts to detect such usage
+        and will log an error.
+
         Args:
             conn
             desc
@@ -536,6 +545,39 @@ class DatabasePool:
             **kwargs
         """
 
+        # Robustness check: ensure that none of the arguments are generators, since that
+        # will fail if we have to repeat the transaction.
+        # For now, we just log an error, and hope that it works on the first attempt.
+        # TODO: raise an exception.
+        for i, arg in enumerate(args):
+            if inspect.isgenerator(arg):
+                logger.error(
+                    "Programming error: generator passed to new_transaction as "
+                    "argument %i to function %s",
+                    i,
+                    func,
+                )
+        for name, val in kwargs.items():
+            if inspect.isgenerator(val):
+                logger.error(
+                    "Programming error: generator passed to new_transaction as "
+                    "argument %s to function %s",
+                    name,
+                    func,
+                )
+        # also check variables referenced in func's closure
+        if inspect.isfunction(func):
+            f = cast(types.FunctionType, func)
+            if f.__closure__:
+                for i, cell in enumerate(f.__closure__):
+                    if inspect.isgenerator(cell.cell_contents):
+                        logger.error(
+                            "Programming error: function %s references generator %s "
+                            "via its closure",
+                            f,
+                            f.__code__.co_freevars[i],
+                        )
+
         start = monotonic_time()
         txn_id = self._TXN_ID
 
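The checks above guard against a subtle failure mode: `new_transaction` may re-run `func` after a database error, and a generator passed as an argument (or captured in `func`'s closure) is exhausted by the first attempt. A minimal illustration of both the failure mode and the closure inspection used above (plain Python, not Synapse code):

    import inspect

    def attempt(rows):
        # Stands in for one attempt at a transaction body.
        return list(rows)

    gen = (i for i in range(3))
    print(attempt(gen))  # [0, 1, 2] -- the first attempt sees the data
    print(attempt(gen))  # []        -- a retry sees an exhausted generator

    # Free variables captured by a function are visible via __closure__ and
    # co_freevars, which is how the check above spots captured generators.
    def make_txn_body(values):
        def txn_body():
            return list(values)
        return txn_body

    body = make_txn_body(x for x in range(3))
    print(body.__code__.co_freevars)                               # ('values',)
    print(inspect.isgenerator(body.__closure__[0].cell_contents))  # True
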
@@ -896,6 +938,9 @@ class DatabasePool:
     ) -> None:
         """Executes an INSERT query on the named table.
 
+        The input is given as a list of dicts, with one dict per row.
+        Generally simple_insert_many_values should be preferred for new code.
+
         Args:
             table: string giving the table name
             values: dict of new column names and values for them
@@ -909,6 +954,9 @@ class DatabasePool:
     ) -> None:
         """Executes an INSERT query on the named table.
 
+        The input is given as a list of dicts, with one dict per row.
+        Generally simple_insert_many_values_txn should be preferred for new code.
+
         Args:
             txn: The transaction to use.
             table: string giving the table name
@@ -933,23 +981,66 @@ class DatabasePool:
             if k != keys[0]:
                 raise RuntimeError("All items must have the same keys")
 
+        return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals)
+
+    async def simple_insert_many_values(
+        self,
+        table: str,
+        keys: Collection[str],
+        values: Collection[Collection[Any]],
+        desc: str,
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        The input is given as a list of rows, where each row is a list of values.
+        (Actually any iterable is fine.)
+
+        Args:
+            table: string giving the table name
+            keys: list of column names
+            values: for each row, a list of values in the same order as `keys`
+            desc: description of the transaction, for logging and metrics
+        """
+        await self.runInteraction(
+            desc, self.simple_insert_many_values_txn, table, keys, values
+        )
+
+    @staticmethod
+    def simple_insert_many_values_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keys: Collection[str],
+        values: Iterable[Iterable[Any]],
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        The input is given as a list of rows, where each row is a list of values.
+        (Actually any iterable is fine.)
+
+        Args:
+            txn: The transaction to use.
+            table: string giving the table name
+            keys: list of column names
+            values: for each row, a list of values in the same order as `keys`
+        """
+
         if isinstance(txn.database_engine, PostgresEngine):
             # We use `execute_values` as it can be a lot faster than `execute_batch`,
             # but it's only available on postgres.
             sql = "INSERT INTO %s (%s) VALUES ?" % (
                 table,
-                ", ".join(k for k in keys[0]),
+                ", ".join(k for k in keys),
             )
 
-            txn.execute_values(sql, vals, fetch=False)
+            txn.execute_values(sql, values, fetch=False)
         else:
             sql = "INSERT INTO %s (%s) VALUES(%s)" % (
                 table,
-                ", ".join(k for k in keys[0]),
-                ", ".join("?" for _ in keys[0]),
+                ", ".join(k for k in keys),
+                ", ".join("?" for _ in keys),
             )
 
-            txn.execute_batch(sql, vals)
+            txn.execute_batch(sql, values)
 
     async def simple_upsert(
         self,
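For contrast, the two calling conventions look roughly like this, written as they would appear inside an async store method (a sketch only; the table and column names are invented for illustration):

    # Older style: one dict per row, all dicts sharing the same keys.
    await self.db_pool.simple_insert_many(
        table="example_counts",
        values=[
            {"user_id": "@alice:example.org", "count": 1},
            {"user_id": "@bob:example.org", "count": 2},
        ],
        desc="example_insert_many",
    )

    # Newer style: the column names once, then one tuple of values per row.
    await self.db_pool.simple_insert_many_values(
        table="example_counts",
        keys=("user_id", "count"),
        values=[("@alice:example.org", 1), ("@bob:example.org", 2)],
        desc="example_insert_many_values",
    )

As the hunk above shows, the dict-based form now just checks that every row has the same keys and delegates to `simple_insert_many_values_txn`, so both paths share the `execute_values`/`execute_batch` logic.
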
@@ -1177,9 +1268,9 @@ class DatabasePool:
         self,
         table: str,
         key_names: Collection[str],
-        key_values: Collection[Iterable[Any]],
+        key_values: Collection[Collection[Any]],
         value_names: Collection[str],
-        value_values: Iterable[Iterable[Any]],
+        value_values: Collection[Collection[Any]],
         desc: str,
     ) -> None:
         """
@@ -1337,7 +1428,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: Literal[False] = False,
         desc: str = "simple_select_one",
     ) -> Dict[str, Any]:
@@ -1348,7 +1439,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: Literal[True] = True,
         desc: str = "simple_select_one",
     ) -> Optional[Dict[str, Any]]:
@@ -1358,7 +1449,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: bool = False,
         desc: str = "simple_select_one",
     ) -> Optional[Dict[str, Any]]:
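The `Literal` overloads above are what let callers get a non-`Optional` return type when they leave `allow_none` at its default. A toy illustration of the mechanism, unrelated to the real query logic:

    from typing import Dict, Optional, overload

    from typing_extensions import Literal

    @overload
    def select_one(allow_none: Literal[False] = False) -> Dict[str, str]:
        ...

    @overload
    def select_one(allow_none: Literal[True] = True) -> Optional[Dict[str, str]]:
        ...

    def select_one(allow_none: bool = False) -> Optional[Dict[str, str]]:
        # Toy stand-in for the real lookup: pretend the row exists.
        return {"device_id": "ABCDEF"}

    row = select_one()                   # type checkers infer Dict[str, str]
    maybe = select_one(allow_none=True)  # type checkers infer Optional[Dict[str, str]]
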
@@ -1528,7 +1619,7 @@ class DatabasePool:
         self,
         table: str,
         keyvalues: Optional[Dict[str, Any]],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         desc: str = "simple_select_list",
     ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
@@ -1591,7 +1682,7 @@ class DatabasePool:
         table: str,
         column: str,
         iterable: Iterable[Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         keyvalues: Optional[Dict[str, Any]] = None,
         desc: str = "simple_select_many_batch",
         batch_size: int = 100,
@@ -1614,16 +1705,7 @@ class DatabasePool:
 
         results: List[Dict[str, Any]] = []
 
-        if not iterable:
-            return results
-
-        # iterables can not be sliced, so convert it to a list first
-        it_list = list(iterable)
-
-        chunks = [
-            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
-        ]
-        for chunk in chunks:
+        for chunk in batch_iter(iterable, batch_size):
             rows = await self.runInteraction(
                 desc,
                 self.simple_select_many_txn,
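`batch_iter` (from `synapse.util.iterutils`) replaces the manual slicing: it lazily yields fixed-size tuples until the source iterable is exhausted, which avoids materialising the whole iterable as a list up front. Roughly, as a simplified sketch of the helper:

    from itertools import islice
    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar("T")

    def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
        # Yield tuples of at most `size` items; stop on the first empty tuple.
        it = iter(iterable)
        return iter(lambda: tuple(islice(it, size)), ())

    print(list(batch_iter(range(5), 2)))  # [(0, 1), (2, 3), (4,)]
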
@@ -1763,7 +1845,7 @@ class DatabasePool:
         txn: LoggingTransaction,
         table: str,
         keyvalues: Dict[str, Any],
-        retcols: Iterable[str],
+        retcols: Collection[str],
         allow_none: bool = False,
     ) -> Optional[Dict[str, Any]]:
         select_sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1871,7 +1953,7 @@ class DatabasePool:
         self,
         table: str,
         column: str,
-        iterable: Iterable[Any],
+        iterable: Collection[Any],
         keyvalues: Dict[str, Any],
         desc: str,
     ) -> int:
@@ -1882,7 +1964,8 @@ class DatabasePool:
         Args:
             table: string giving the table name
             column: column name to test for inclusion against `iterable`
-            iterable: list
+            iterable: list of values to match against `column`. NB cannot be a generator
+                as it may be evaluated multiple times.
             keyvalues: dict of column names and values to select the rows with
             desc: description of the transaction, for logging and metrics
 
@@ -2055,7 +2138,7 @@ class DatabasePool:
         table: str,
         term: Optional[str],
         col: str,
-        retcols: Iterable[str],
+        retcols: Collection[str],
         desc="simple_search_list",
     ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9ff2d8d8c3..f024761ba7 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@ import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -68,7 +68,7 @@ from .session import SessionStore
 from .signatures import SignatureStore
 from .state import StateStore
 from .stats import StatsStore
-from .stream import StreamStore
+from .stream import StreamWorkerStore
 from .tags import TagsStore
 from .transactions import TransactionWorkerStore
 from .ui_auth import UIAuthStore
@@ -87,7 +87,7 @@ class DataStore(
     RoomStore,
     RoomBatchStore,
     RegistrationStore,
-    StreamStore,
+    StreamWorkerStore,
     ProfileStore,
     PresenceStore,
     TransactionWorkerStore,
@@ -129,7 +129,12 @@ class DataStore(
     LockStore,
     SessionStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
@@ -143,11 +148,7 @@ class DataStore(
                 ("device_lists_outbound_pokes", "stream_id"),
             ],
         )
-        self._cross_signing_id_gen = StreamIdGenerator(
-            db_conn, "e2e_cross_signing_keys", "stream_id"
-        )
 
-        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
         self._group_updates_id_gen = StreamIdGenerator(
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index f8bec266ac..32a553fdd7 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,15 +14,25 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage._base import db_to_json
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -34,13 +44,19 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class AccountDataWorkerStore(SQLBaseStore):
-    """This is an abstract base class where subclasses must implement
-    `get_max_account_data_stream_id` which can be called in the initializer.
-    """
+class AccountDataWorkerStore(CacheInvalidationWorkerStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
-        self._instance_name = hs.get_instance_name()
+        # `_can_write_to_account_data` indicates whether the current worker is allowed
+        # to write account data. A value of `True` implies that `_account_data_id_gen`
+        # is an `AbstractStreamIdGenerator` and not just a tracker.
+        self._account_data_id_gen: AbstractStreamIdTracker
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_account_data = (
@@ -61,8 +77,6 @@ class AccountDataWorkerStore(SQLBaseStore):
                 writers=hs.config.worker.writers.account_data,
             )
         else:
-            self._can_write_to_account_data = True
-
             # We shouldn't be running in worker mode with SQLite, but its useful
             # to support it for unit tests.
             #
@@ -70,7 +84,8 @@ class AccountDataWorkerStore(SQLBaseStore):
             # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
             # updated over replication. (Multiple writers are not supported for
             # SQLite).
-            if hs.get_instance_name() in hs.config.worker.writers.account_data:
+            if self._instance_name in hs.config.worker.writers.account_data:
+                self._can_write_to_account_data = True
                 self._account_data_id_gen = StreamIdGenerator(
                     db_conn,
                     "room_account_data",
@@ -90,8 +105,6 @@ class AccountDataWorkerStore(SQLBaseStore):
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super().__init__(database, db_conn, hs)
-
     def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream ID for account data stream
 
@@ -113,7 +126,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             room_id string to per room account_data dicts.
         """
 
-        def get_account_data_for_user_txn(txn):
+        def get_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
             rows = self.db_pool.simple_select_list_txn(
                 txn,
                 "account_data",
@@ -132,7 +147,7 @@ class AccountDataWorkerStore(SQLBaseStore):
                 ["room_id", "account_data_type", "content"],
             )
 
-            by_room = {}
+            by_room: Dict[str, Dict[str, JsonDict]] = {}
             for row in rows:
                 room_data = by_room.setdefault(row["room_id"], {})
                 room_data[row["account_data_type"]] = db_to_json(row["content"])
@@ -177,7 +192,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             A dict of the room account_data
         """
 
-        def get_account_data_for_room_txn(txn):
+        def get_account_data_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, JsonDict]:
             rows = self.db_pool.simple_select_list_txn(
                 txn,
                 "room_account_data",
@@ -207,7 +224,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             The room account_data for that type, or None if there isn't any set.
         """
 
-        def get_account_data_for_room_and_type_txn(txn):
+        def get_account_data_for_room_and_type_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[JsonDict]:
             content_json = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 table="room_account_data",
@@ -243,14 +262,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return []
 
-        def get_updated_global_account_data_txn(txn):
+        def get_updated_global_account_data_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str]]:
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_updated_global_account_data", get_updated_global_account_data_txn
@@ -273,14 +294,16 @@ class AccountDataWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return []
 
-        def get_updated_room_account_data_txn(txn):
+        def get_updated_room_account_data_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str]]:
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_updated_room_account_data", get_updated_room_account_data_txn
@@ -299,7 +322,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             mapping from room_id string to per room account_data dicts.
         """
 
-        def get_updated_account_data_for_user_txn(txn):
+        def get_updated_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
             sql = (
                 "SELECT account_data_type, content FROM account_data"
                 " WHERE user_id = ? AND stream_id > ?"
@@ -316,7 +341,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             txn.execute(sql, (user_id, stream_id))
 
-            account_data_by_room = {}
+            account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
             for row in txn:
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = db_to_json(row[2])
@@ -353,12 +378,15 @@ class AccountDataWorkerStore(SQLBaseStore):
             )
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == TagAccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
         elif stream_name == AccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
@@ -372,7 +400,8 @@ class AccountDataWorkerStore(SQLBaseStore):
                     (row.user_id, row.room_id, row.data_type)
                 )
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
 
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -389,6 +418,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -431,6 +461,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
@@ -452,7 +483,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
     def _add_account_data_for_user(
         self,
-        txn,
+        txn: LoggingTransaction,
         next_id: int,
         user_id: str,
         account_data_type: str,
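The `_account_data_id_gen` attribute is annotated with the broader tracker type because most workers only follow the account data stream over replication; only the designated writer assigns new ids. Write paths then narrow the type with an `isinstance` assertion before calling `get_next()`. A generic sketch of that narrowing pattern (the class names below are stand-ins, not the real Synapse classes):

    class IdTracker:
        """Can only report the current position (what most workers hold)."""

        def get_current_token(self) -> int:
            return 0

    class IdGenerator(IdTracker):
        """Can also hand out new ids (what the designated writer holds)."""

        def get_next(self) -> int:
            return 1

    id_gen: IdTracker = IdGenerator()  # declared with the broad type

    # Before a write, narrow the type: mypy then accepts get_next(), and a
    # misconfigured worker fails fast instead of writing to the stream.
    assert isinstance(id_gen, IdGenerator)
    print(id_gen.get_next())
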
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..92c95a41d7 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -24,9 +24,8 @@ from synapse.appservice import (
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -58,7 +57,12 @@ def _make_exclusive_regex(
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.appservice.app_service_config_files
         )
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc6..0024348067 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -25,7 +25,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
 )
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -41,7 +41,12 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 0f56e10220..fd3fc298b3 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -18,7 +18,11 @@ from typing import TYPE_CHECKING, Optional
 from synapse.events.utils import prune_event_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.util import json_encoder
@@ -31,7 +35,12 @@ logger = logging.getLogger(__name__)
 
 
 class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if (
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index a6fd9f2636..f3881671fd 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -26,7 +26,6 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 
@@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):
 
 
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -394,7 +398,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -532,7 +541,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
 
 class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
 
         # (user_id, access_token, ip,) -> last_seen
         self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ab8766c75b..3682cb6a81 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
     REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -668,7 +673,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
                 # There's a type mismatch here between how we want to type the row and
                 # what fetchone says it returns, but we silence it because we know that
                 # res can't be None.
-                res: Tuple[Optional[int]] = txn.fetchone()  # type: ignore[assignment]
+                res = cast(Tuple[Optional[int]], txn.fetchone())
                 if res[0] is None:
                     # this can only happen if the `device_inbox` table is empty, in which
                     # case we have no work to do.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d5a4a661cd..273adb61fd 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,6 +38,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
@@ -61,7 +62,12 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -101,7 +107,9 @@ class DeviceWorkerStore(SQLBaseStore):
             "count_devices_by_users", count_devices_by_users_txn, user_ids
         )
 
-    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+    async def get_device(
+        self, user_id: str, device_id: str
+    ) -> Optional[Dict[str, Any]]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
 
@@ -109,15 +117,35 @@ class DeviceWorkerStore(SQLBaseStore):
             user_id: The ID of the user which owns the device
             device_id: The ID of the device to retrieve
         Returns:
-            A dict containing the device information
-        Raises:
-            StoreError: if the device is not found
+            A dict containing the device information, or `None` if the device does not
+            exist.
         """
         return await self.db_pool.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
             desc="get_device",
+            allow_none=True,
+        )
+
+    async def get_device_opt(
+        self, user_id: str, device_id: str
+    ) -> Optional[Dict[str, Any]]:
+        """Retrieve a device. Only returns devices that are not marked as
+        hidden.
+
+        Args:
+            user_id: The ID of the user which owns the device
+            device_id: The ID of the device to retrieve
+        Returns:
+            A dict containing the device information, or None if the device does not exist.
+        """
+        return await self.db_pool.simple_select_one(
+            table="devices",
+            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
+            retcols=("user_id", "device_id", "display_name"),
+            desc="get_device",
+            allow_none=True,
         )
 
     async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
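With `allow_none=True`, callers of `get_device` now receive `None` for an unknown device instead of catching `StoreError`. A sketch of the caller-side shape this implies (hypothetical handler code, not taken from this diff):

    from synapse.api.errors import NotFoundError

    device = await self.store.get_device(user_id, device_id)
    if device is None:
        raise NotFoundError("Unknown device")
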
@@ -274,7 +302,9 @@ class DeviceWorkerStore(SQLBaseStore):
         # add the updated cross-signing keys to the results list
         for user_id, result in cross_signing_keys_by_user.items():
             result["user_id"] = user_id
-            # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+            results.append(("m.signing_key_update", result))
+            # also send the unstable version
+            # FIXME: remove this when enough servers have upgraded
             results.append(("org.matrix.signing_key_update", result))
 
         return now_stream_id, results
@@ -949,7 +979,12 @@ class DeviceWorkerStore(SQLBaseStore):
 
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -1081,7 +1116,12 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index a3442814d7..f76c6121e8 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,16 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
 from typing import Iterable, List, Optional, Tuple
 
+import attr
+
 from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.types import RoomAlias
 from synapse.util.caches.descriptors import cached
 
-RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomAliasMapping:
+    room_id: str
+    room_alias: str
+    servers: List[str]
 
 
 class DirectoryWorkerStore(CacheInvalidationWorkerStore):
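Converting the `namedtuple` to a frozen attrs class keeps attribute access identical while adding type annotations and slots; the one behavioural difference is that instances are no longer tuples, so index-based access and unpacking stop working. A small usage sketch with made-up values:

    from typing import List

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class RoomAliasMapping:
        room_id: str
        room_alias: str
        servers: List[str]

    mapping = RoomAliasMapping("!abc:example.org", "#room:example.org", ["example.org"])
    print(mapping.room_id)   # attribute access works as before
    # mapping.room_id = "x"  # would raise attr.exceptions.FrozenInstanceError
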
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index b15fb71e62..0cb48b9dd7 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -13,35 +13,71 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
+
+from typing_extensions import Literal, TypedDict
 
 from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
+from synapse.types import JsonDict, JsonSerializable
 from synapse.util import json_encoder
 
 
+class RoomKey(TypedDict):
+    """`KeyBackupData` in the Matrix spec.
+
+    https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
+    """
+
+    first_message_index: int
+    forwarded_count: int
+    is_verified: bool
+    session_data: JsonSerializable
+
+
 class EndToEndRoomKeyStore(SQLBaseStore):
+    """The store for end to end room key backups.
+
+    See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups
+
+    As per the spec, backups are identified by an opaque version string. Internally,
+    version identifiers are assigned using incrementing integers. Non-numeric version
+    strings are treated as if they do not exist, since we would have never issued them.
+    """
+
     async def update_e2e_room_key(
-        self, user_id, version, room_id, session_id, room_key
-    ):
+        self,
+        user_id: str,
+        version: str,
+        room_id: str,
+        session_id: str,
+        room_key: RoomKey,
+    ) -> None:
         """Replaces the encrypted E2E room key for a given session in a given backup
 
         Args:
-            user_id(str): the user whose backup we're setting
-            version(str): the version ID of the backup we're updating
-            room_id(str): the ID of the room whose keys we're setting
-            session_id(str): the session whose room_key we're setting
-            room_key(dict): the room_key being set
+            user_id: the user whose backup we're setting
+            version: the version ID of the backup we're updating
+            room_id: the ID of the room whose keys we're setting
+            session_id: the session whose room_key we're setting
+            room_key: the room_key being set
         Raises:
             StoreError
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            raise StoreError(404, "No backup with that version exists")
 
         await self.db_pool.simple_update_one(
             table="e2e_room_keys",
             keyvalues={
                 "user_id": user_id,
-                "version": version,
+                "version": version_int,
                 "room_id": room_id,
                 "session_id": session_id,
             },
@@ -54,22 +90,29 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             desc="update_e2e_room_key",
         )
 
-    async def add_e2e_room_keys(self, user_id, version, room_keys):
+    async def add_e2e_room_keys(
+        self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]]
+    ) -> None:
         """Bulk add room keys to a given backup.
 
         Args:
-            user_id (str): the user whose backup we're adding to
-            version (str): the version ID of the backup for the set of keys we're adding to
-            room_keys (iterable[(str, str, dict)]): the keys to add, in the form
-                (roomID, sessionID, keyData)
+            user_id: the user whose backup we're adding to
+            version: the version ID of the backup for the set of keys we're adding to
+            room_keys: the keys to add, in the form (roomID, sessionID, keyData)
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            raise StoreError(404, "No backup with that version exists")
 
         values = []
         for (room_id, session_id, room_key) in room_keys:
             values.append(
                 {
                     "user_id": user_id,
-                    "version": version,
+                    "version": version_int,
                     "room_id": room_id,
                     "session_id": session_id,
                     "first_message_index": room_key["first_message_index"],
@@ -92,31 +135,39 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @trace
-    async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_e2e_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> Dict[
+        Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+    ]:
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup for the set of keys we're querying
-            room_id (str): Optional. the ID of the room whose keys we're querying, if any.
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup for the set of keys we're querying
+            room_id: Optional. the ID of the room whose keys we're querying, if any.
                 If not specified, we return the keys for all the rooms in the backup.
-            session_id (str): Optional. the session whose room_key we're querying, if any.
+            session_id: Optional. the session whose room_key we're querying, if any.
                 If specified, we also require the room_id to be specified.
                 If not specified, we return all the keys in this version of
                 the backup (or for the specified room)
 
         Returns:
-            A list of dicts giving the session_data and message metadata for
-            these room keys.
+            A dict giving the session_data and message metadata for these room keys.
+            `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
         """
 
         try:
-            version = int(version)
+            version_int = int(version)
         except ValueError:
             return {"rooms": {}}
 
-        keyvalues = {"user_id": user_id, "version": version}
+        keyvalues = {"user_id": user_id, "version": version_int}
         if room_id:
             keyvalues["room_id"] = room_id
             if session_id:
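Every public method in this store now applies the same convention up front: parse the client-supplied version string as an int, and treat anything non-numeric as a backup that does not exist, raising 404 on write paths and returning an empty result (or nothing to delete) on read and delete paths. A hypothetical helper capturing the write-path variant of the pattern (not a function in the source):

    from synapse.api.errors import StoreError

    def parse_backup_version(version: str) -> int:
        """Hypothetical helper: map an opaque version string to the internal int id."""
        try:
            return int(version)
        except ValueError:
            # Versions are only ever issued as integers, so anything else cannot exist.
            raise StoreError(404, "No backup with that version exists")
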
@@ -137,7 +188,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             desc="get_e2e_room_keys",
         )
 
-        sessions = {"rooms": {}}
+        sessions: Dict[
+            Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
+        ] = {"rooms": {}}
         for row in rows:
             room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
             room_entry["sessions"][row["session_id"]] = {
@@ -150,7 +203,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         return sessions
 
-    async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+    async def get_e2e_room_keys_multi(
+        self,
+        user_id: str,
+        version: str,
+        room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+    ) -> Dict[str, Dict[str, RoomKey]]:
         """Get multiple room keys at a time.  The difference between this function and
         get_e2e_room_keys is that this function can be used to retrieve
         multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -158,26 +216,36 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         specific key.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup we're querying about
-            room_keys (dict[str, dict[str, iterable[str]]]): a map from
-                room ID -> {"session": [session ids]} indicating the session IDs
-                that we want to query
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup we're querying about
+            room_keys: a map from room ID -> {"sessions": [session ids]}
+                indicating the session IDs that we want to query
 
         Returns:
-           dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
+           A map of room IDs to session IDs to room key
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            return {}
 
         return await self.db_pool.runInteraction(
             "get_e2e_room_keys_multi",
             self._get_e2e_room_keys_multi_txn,
             user_id,
-            version,
+            version_int,
             room_keys,
         )
 
     @staticmethod
-    def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+    def _get_e2e_room_keys_multi_txn(
+        txn: LoggingTransaction,
+        user_id: str,
+        version: int,
+        room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
+    ) -> Dict[str, Dict[str, RoomKey]]:
         if not room_keys:
             return {}
 
@@ -209,7 +277,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         txn.execute(sql, params)
 
-        ret = {}
+        ret: Dict[str, Dict[str, RoomKey]] = {}
 
         for row in txn:
             room_id = row[0]
@@ -231,36 +299,49 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             user_id: the user whose backup we're querying
             version: the version ID of the backup we're querying about
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            return 0
 
         return await self.db_pool.simple_select_one_onecol(
             table="e2e_room_keys",
-            keyvalues={"user_id": user_id, "version": version},
+            keyvalues={"user_id": user_id, "version": version_int},
             retcol="COUNT(*)",
             desc="count_e2e_room_keys",
         )
 
     @trace
     async def delete_e2e_room_keys(
-        self, user_id, version, room_id=None, session_id=None
-    ):
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> None:
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
 
         Args:
-            user_id(str): the user whose backup we're deleting from
-            version(str): the version ID of the backup for the set of keys we're deleting
-            room_id(str): Optional. the ID of the room whose keys we're deleting, if any.
+            user_id: the user whose backup we're deleting from
+            version: the version ID of the backup for the set of keys we're deleting
+            room_id: Optional. the ID of the room whose keys we're deleting, if any.
                 If not specified, we delete the keys for all the rooms in the backup.
-            session_id(str): Optional. the session whose room_key we're querying, if any.
+            session_id: Optional. the session whose room_key we're querying, if any.
                 If specified, we also require the room_id to be specified.
                 If not specified, we delete all the keys in this version of
                 the backup (or for the specified room)
-
-        Returns:
-            The deletion transaction
         """
+        try:
+            version_int = int(version)
+        except ValueError:
+            # Our versions are all ints so if we can't convert it to an integer,
+            # it doesn't exist.
+            return
 
-        keyvalues = {"user_id": user_id, "version": int(version)}
+        keyvalues = {"user_id": user_id, "version": version_int}
         if room_id:
             keyvalues["room_id"] = room_id
             if session_id:
@@ -271,23 +352,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @staticmethod
-    def _get_current_version(txn, user_id):
+    def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
         txn.execute(
             "SELECT MAX(version) FROM e2e_room_keys_versions "
             "WHERE user_id=? AND deleted=0",
             (user_id,),
         )
-        row = txn.fetchone()
-        if not row:
+        # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
+        # be `NULL` when there are no available versions.
+        row = cast(Tuple[Optional[int]], txn.fetchone())
+        if row[0] is None:
             raise StoreError(404, "No current backup version")
         return row[0]
 
-    async def get_e2e_room_keys_version_info(self, user_id, version=None):
+    async def get_e2e_room_keys_version_info(
+        self, user_id: str, version: Optional[str] = None
+    ) -> JsonDict:
         """Get info metadata about a version of our room_keys backup.
 
         Args:
-            user_id(str): the user whose backup we're querying
-            version(str): Optional. the version ID of the backup we're querying about
+            user_id: the user whose backup we're querying
+            version: Optional. the version ID of the backup we're querying about
                 If missing, we return the information about the current version.
         Raises:
             StoreError: with code 404 if there are no e2e_room_keys_versions present
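The comment above relies on a standard SQL rule: an aggregate query with no `GROUP BY` always produces exactly one row, so `fetchone()` cannot return `None` here; only the value inside the row can be `NULL`. A quick standalone check with sqlite3:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE e2e_room_keys_versions (user_id TEXT, version INTEGER, deleted INTEGER)"
    )
    row = conn.execute(
        "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=? AND deleted=0",
        ("@nobody:test",),
    ).fetchone()
    print(row)  # (None,) -- one row comes back even though the table is empty
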
@@ -300,7 +385,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 etag(int): tag of the keys in the backup
         """
 
-        def _get_e2e_room_keys_version_info_txn(txn):
+        def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
             if version is None:
                 this_version = self._get_current_version(txn, user_id)
             else:
@@ -309,14 +394,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 except ValueError:
                     # Our versions are all ints so if we can't convert it to an integer,
                     # it isn't there.
-                    raise StoreError(404, "No row found")
+                    raise StoreError(404, "No backup with that version exists")
 
             result = self.db_pool.simple_select_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
                 retcols=("version", "algorithm", "auth_data", "etag"),
+                allow_none=False,
             )
+            assert result is not None  # see comment on `simple_select_one_txn`
             result["auth_data"] = db_to_json(result["auth_data"])
             result["version"] = str(result["version"])
             if result["etag"] is None:
@@ -328,28 +415,28 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @trace
-    async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
+    async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str:
         """Atomically creates a new version of this user's e2e_room_keys store
         with the given version info.
 
         Args:
-            user_id(str): the user whose backup we're creating a version
-            info(dict): the info about the backup version to be created
+            user_id: the user whose backup we're creating a version
+            info: the info about the backup version to be created
 
         Returns:
             The newly created version ID
         """
 
-        def _create_e2e_room_keys_version_txn(txn):
+        def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
             txn.execute(
                 "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
                 (user_id,),
             )
-            current_version = txn.fetchone()[0]
+            current_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
             if current_version is None:
-                current_version = "0"
+                current_version = 0
 
-            new_version = str(int(current_version) + 1)
+            new_version = current_version + 1
 
             self.db_pool.simple_insert_txn(
                 txn,
@@ -362,7 +449,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 },
             )
 
-            return new_version
+            return str(new_version)
 
         return await self.db_pool.runInteraction(
             "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
@@ -373,7 +460,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         self,
         user_id: str,
         version: str,
-        info: Optional[dict] = None,
+        info: Optional[JsonDict] = None,
         version_etag: Optional[int] = None,
     ) -> None:
         """Update a given backup version
@@ -386,7 +473,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             version_etag: etag of the keys in the backup. If None, then the etag
                 is not updated.
         """
-        updatevalues = {}
+        updatevalues: Dict[str, object] = {}
 
         if info is not None and "auth_data" in info:
             updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
@@ -394,9 +481,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             updatevalues["etag"] = version_etag
 
         if updatevalues:
-            await self.db_pool.simple_update(
+            try:
+                version_int = int(version)
+            except ValueError:
+                # Our versions are all ints so if we can't convert it to an integer,
+                # it doesn't exist.
+                raise StoreError(404, "No backup with that version exists")
+
+            await self.db_pool.simple_update_one(
                 table="e2e_room_keys_versions",
-                keyvalues={"user_id": user_id, "version": version},
+                keyvalues={"user_id": user_id, "version": version_int},
                 updatevalues=updatevalues,
                 desc="update_e2e_room_keys_version",
             )
@@ -417,13 +511,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 or if the version requested doesn't exist.
         """
 
-        def _delete_e2e_room_keys_version_txn(txn):
+        def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
             if version is None:
                 this_version = self._get_current_version(txn, user_id)
-                if this_version is None:
-                    raise StoreError(404, "No current backup version")
             else:
-                this_version = version
+                try:
+                    this_version = int(version)
+                except ValueError:
+                    # Our versions are all ints so if we can't convert it to an integer,
+                    # it isn't there.
+                    raise StoreError(404, "No backup with that version exists")
 
             self.db_pool.simple_delete_txn(
                 txn,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b06c1dc45b..57b5ffbad3 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,19 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 import attr
 from canonicaljson import encode_canonical_json
 
-from twisted.enterprise.adbapi import Connection
-
 from synapse.api.constants import DeviceKeyAlgorithms
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -50,7 +63,12 @@ class DeviceKeyLookupResult:
 
 
 class EndToEndKeyBackgroundStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
         )
 
 
-class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._allow_device_name_lookup_over_federation = (
@@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         # Build the result structure, un-jsonify the results, and add the
         # "unsigned" section
-        rv = {}
+        rv: Dict[str, Dict[str, JsonDict]] = {}
         for user_id, device_keys in results.items():
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
@@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             # add each cross-signing signature to the correct device in the result dict.
             for (user_id, key_id, device_id, signature) in cross_sigs_result:
                 target_device_result = result[user_id][device_id]
+                # We've only looked up cross-signatures for non-deleted devices with key
+                # data.
+                assert target_device_result is not None
+                assert target_device_result.keys is not None
                 target_device_signatures = target_device_result.keys.setdefault(
                     "signatures", {}
                 )
@@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return result
 
     def _get_e2e_device_keys_txn(
-        self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+        self,
+        txn: LoggingTransaction,
+        query_list: Collection[Tuple[str, str]],
+        include_all_devices: bool = False,
+        include_deleted_devices: bool = False,
     ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
         """Get information on devices from the database
 
@@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return result
 
     def _get_e2e_cross_signing_signatures_for_devices_txn(
-        self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+        self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
     ) -> List[Tuple[str, str, str, str]]:
         """Get cross-signing signatures for a given list of devices
 
@@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         )
 
         txn.execute(signature_sql, signature_query_params)
-        return txn.fetchall()
+        return cast(
+            List[
+                Tuple[
+                    str,
+                    str,
+                    str,
+                    str,
+                ]
+            ],
+            txn.fetchall(),
+        )
 
     async def get_e2e_one_time_keys(
         self, user_id: str, device_id: str, key_ids: List[str]
@@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
         """
 
-        def _add_e2e_one_time_keys(txn):
+        def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("new_keys", new_keys)
@@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             A mapping from algorithm to number of keys for that algorithm.
         """
 
-        def _count_e2e_one_time_keys(txn):
+        def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
             sql = (
                 "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
                 " WHERE user_id = ? AND device_id = ?"
@@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         )
 
     def _set_e2e_fallback_keys_txn(
-        self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        fallback_keys: JsonDict,
     ) -> None:
         # fallback_keys will usually only have one item in it, so using a for
         # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
@@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
-    ) -> Optional[dict]:
+    ) -> Optional[JsonDict]:
         """Returns a user's cross-signing key.
 
         Args:
@@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         return user_keys.get(key_type)
 
     @cached(num_args=1)
-    def _get_bare_e2e_cross_signing_keys(self, user_id):
+    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
         """Dummy function.  Only used to make a cache for
         _get_bare_e2e_cross_signing_keys_bulk.
         """
@@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
@@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             their user ID will map to None.
 
         """
-        return await self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             "get_bare_e2e_cross_signing_keys_bulk",
             self._get_bare_e2e_cross_signing_keys_bulk_txn,
             user_ids,
         )
 
+        # The `Optional` comes from the `@cachedList` decorator.
+        return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+
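
The cast above reflects the `@cachedList` contract: every requested user ID appears in the result, with `None` standing in where no cross-signing keys were found. A sketch of the shape callers should expect; the user IDs and key material are made up:

result = {
    "@alice:example.org": {
        "master": {
            "user_id": "@alice:example.org",
            "usage": ["master"],
            "keys": {"ed25519:alicemasterpubkey": "alicemasterpubkey"},
        },
    },
    "@bob:example.org": None,  # requested, but no keys stored for this user
}

assert result["@bob:example.org"] is None
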
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
-        txn: Connection,
+        txn: LoggingTransaction,
         user_ids: Iterable[str],
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Dict[str, JsonDict]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_ids (list[str]): the users whose keys are being requested
+            txn: db connection
+            user_ids: the users whose keys are being requested
 
         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-                data.  If a user's cross-signing keys were not found, their user
-                ID will not be in the dict.
+            Mapping from user ID to key type to key data.
+            If a user's cross-signing keys were not found, their user ID will not be in
+            the dict.
 
         """
-        result = {}
+        result: Dict[str, Dict[str, JsonDict]] = {}
 
         for user_chunk in batch_iter(user_ids, 100):
             clause, params = make_in_list_sql_clause(
@@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
                 user_id = row["user_id"]
                 key_type = row["keytype"]
                 key = db_to_json(row["keydata"])
-                user_info = result.setdefault(user_id, {})
-                user_info[key_type] = key
+                user_keys = result.setdefault(user_id, {})
+                user_keys[key_type] = key
 
         return result
 
     def _get_e2e_cross_signing_signatures_txn(
         self,
-        txn: Connection,
-        keys: Dict[str, Dict[str, dict]],
+        txn: LoggingTransaction,
+        keys: Dict[str, Optional[Dict[str, JsonDict]]],
         from_user_id: str,
-    ) -> Dict[str, Dict[str, dict]]:
+    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
         """Returns the cross-signing signatures made by a user on a set of keys.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
-                key data.  This dict will be modified to add signatures.
-            from_user_id (str): fetch the signatures made by this user
+            txn: db connection
+            keys: a map of user ID to key type to key data.
+                This dict will be modified to add signatures.
+            from_user_id: fetch the signatures made by this user
 
         Returns:
-            dict[str, dict[str, dict]]: mapping from user ID to key type to key
-                data.  The return value will be the same as the keys argument,
-                with the modifications included.
+            Mapping from user ID to key type to key data.
+            The return value will be the same as the keys argument, with the
+            modifications included.
         """
 
         # find out what cross-signing keys (a.k.a. devices) we need to get
         # signatures for.  This is a map of (user_id, device_id) to key type
         # (device_id is the key's public part).
-        devices = {}
+        devices: Dict[Tuple[str, str], str] = {}
 
-        for user_id, user_info in keys.items():
-            if user_info is None:
+        for user_id, user_keys in keys.items():
+            if user_keys is None:
                 continue
-            for key_type, key in user_info.items():
+            for key_type, key in user_keys.items():
                 device_id = None
                 for k in key["keys"].values():
                     device_id = k
+                # `key` ought to be a `CrossSigningKey`, whose .keys property is a
+                # dictionary with a single entry:
+                #     "algorithm:base64_public_key": "base64_public_key"
+                # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
+                assert isinstance(device_id, str)
                 devices[(user_id, device_id)] = key_type
 
         for batch in batch_iter(devices.keys(), size=100):
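
The assertion added above relies on the cross-signing key format from the spec link in the comment: the `keys` mapping has exactly one entry, whose value (the base64 public key) doubles as the device ID. A small illustration with an invented key:

self_signing_key = {
    "user_id": "@alice:example.org",
    "usage": ["self_signing"],
    "keys": {"ed25519:selfsigningpubkey": "selfsigningpubkey"},
}

# Same extraction as the loop above: take the single value of `keys`,
# which is then used as the "device ID" for the signature lookup.
device_id = None
for k in self_signing_key["keys"].values():
    device_id = k

assert isinstance(device_id, str)
assert device_id == "selfsigningpubkey"
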
@@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
             # and add the signatures to the appropriate keys
             for row in rows:
-                key_id = row["key_id"]
-                target_user_id = row["target_user_id"]
-                target_device_id = row["target_device_id"]
+                key_id: str = row["key_id"]
+                target_user_id: str = row["target_user_id"]
+                target_device_id: str = row["target_device_id"]
                 key_type = devices[(target_user_id, target_device_id)]
                 # We need to copy everything, because the result may have come
                 # from the cache.  dict.copy only does a shallow copy, so we
                 # need to recursively copy the dicts that will be modified.
-                user_info = keys[target_user_id] = keys[target_user_id].copy()
-                target_user_key = user_info[key_type] = user_info[key_type].copy()
+                user_keys = keys[target_user_id]
+                # `user_keys` cannot be `None` because we only fetched signatures for
+                # users with keys
+                assert user_keys is not None
+                user_keys = keys[target_user_id] = user_keys.copy()
+
+                target_user_key = user_keys[key_type] = user_keys[key_type].copy()
                 if "signatures" in target_user_key:
                     signatures = target_user_key["signatures"] = target_user_key[
                         "signatures"
@@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Optional[Dict[str, dict]]]:
+    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def _get_all_user_signature_changes_for_remotes_txn(txn):
+        def _get_all_user_signature_changes_for_remotes_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             sql = """
                 SELECT stream_id, from_user_id AS user_id
                 FROM user_signature_stream
@@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         @trace
         def _claim_e2e_one_time_key_simple(
-            txn, user_id: str, device_id: str, algorithm: str
+            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
         ) -> Optional[Tuple[str, str]]:
             """Claim OTK for device for DBs that don't support RETURNING.
 
@@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
         @trace
         def _claim_e2e_one_time_key_returning(
-            txn, user_id: str, device_id: str, algorithm: str
+            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
         ) -> Optional[Tuple[str, str]]:
             """Claim OTK for device for DBs that support RETURNING.
 
@@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             key_id, key_json = otk_row
             return f"{algorithm}:{key_id}", key_json
 
-        results = {}
+        results: Dict[str, Dict[str, Dict[str, str]]] = {}
         for user_id, device_id, algorithm in query_list:
             if self.database_engine.supports_returning:
                 # If we support RETURNING clause we can use a single query that
@@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
 
 
 class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._cross_signing_id_gen = StreamIdGenerator(
+            db_conn, "e2e_cross_signing_keys", "stream_id"
+        )
+
     async def set_e2e_device_keys(
         self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
     ) -> bool:
@@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         or the keys were already in the database.
         """
 
-        def _set_e2e_device_keys_txn(txn):
+        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("time_now", time_now)
@@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         )
 
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
-        def delete_e2e_keys_by_device_txn(txn):
+        def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
             log_kv(
                 {
                     "message": "Deleting keys for device",
@@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
-    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
+    def _set_e2e_cross_signing_key_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        key_type: str,
+        key: JsonDict,
+        stream_id: int,
+    ) -> None:
         """Set a user's cross-signing key.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_id (str): the user to set the signing key for
-            key_type (str): the type of key that is being set: either 'master'
+            txn: db connection
+            user_id: the user to set the signing key for
+            key_type: the type of key that is being set: either 'master'
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
-            key (dict): the key data
-            stream_id (int)
+            key: the key data
+            stream_id: the stream ID to store alongside the key
         """
         # the 'key' dict will look something like:
         # {
@@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
         )
 
-    async def set_e2e_cross_signing_key(self, user_id, key_type, key):
+    async def set_e2e_cross_signing_key(
+        self, user_id: str, key_type: str, key: JsonDict
+    ) -> None:
         """Set a user's cross-signing key.
 
         Args:
-            user_id (str): the user to set the user-signing key for
-            key_type (str): the type of cross-signing key to set
-            key (dict): the key data
+            user_id: the user to set the user-signing key for
+            key_type: the type of cross-signing key to set
+            key: the key data
         """
 
         async with self._cross_signing_id_gen.get_next() as stream_id:
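
The stream_id claimed here comes from the StreamIdGenerator set up in the new EndToEndKeyStore constructor above. A hedged sketch of the allocation pattern; the body is illustrative rather than the method's real implementation:

async def store_cross_signing_key(store, user_id: str, key_type: str, key: dict) -> None:
    # get_next() is an async context manager yielding the next stream ID; the
    # database write happens inside the block so the ID only becomes visible
    # to readers once the write has completed.
    async with store._cross_signing_id_gen.get_next() as stream_id:
        await store.db_pool.runInteraction(
            "set_e2e_cross_signing_key",
            store._set_e2e_cross_signing_key_txn,
            user_id,
            key_type,
            key,
            stream_id,
        )
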
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..270b30800b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine
@@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception):
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -279,7 +288,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             new_front = set()
             for chunk in batch_iter(front, 100):
                 # Pull the auth events either from the cache or DB.
-                to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+                to_fetch: List[str] = []  # Event IDs to fetch from DB
                 for event_id in chunk:
                     res = self._event_auth_cache.get(event_id)
                     if res is None:
@@ -606,8 +615,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             # currently walking, either from cache or DB.
             search, chunk = search[:-100], search[-100:]
 
-            found = []  # Results found  # type: List[Tuple[str, str, int]]
-            to_fetch = []  # Event IDs to fetch from DB  # type: List[str]
+            found: List[Tuple[str, str, int]] = []  # Results found
+            to_fetch: List[str] = []  # Event IDs to fetch from DB
             for _, event_id in chunk:
                 res = self._event_auth_cache.get(event_id)
                 if res is None:
@@ -1384,7 +1393,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         count = await self.db_pool.simple_select_one_onecol(
             table="federation_inbound_events_staging",
             keyvalues={"room_id": room_id},
-            retcol="COALESCE(COUNT(*), 0)",
+            retcol="COUNT(*)",
             desc="prune_staged_events_in_room_count",
         )
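
COUNT(*) already evaluates to 0 (never NULL) over an empty result set, which is why the COALESCE wrappers in this hunk and the next are dropped. A quick sqlite3 check of that assumption:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE federation_inbound_events_staging (room_id TEXT)")

# Counting rows in an empty table yields 0, not NULL, so COALESCE(COUNT(*), 0)
# is redundant.
(count,) = conn.execute(
    "SELECT COUNT(*) FROM federation_inbound_events_staging"
).fetchone()
assert count == 0
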
 
@@ -1476,9 +1485,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """Update the prometheus metrics for the inbound federation staging area."""
 
         def _get_stats_for_federation_staging_txn(txn):
-            txn.execute(
-                "SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging"
-            )
+            txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
             (count,) = txn.fetchone()
 
             txn.execute(
@@ -1514,7 +1521,12 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..a98e6b2593 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,14 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import attr
-from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -30,29 +33,64 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
-DEFAULT_HIGHLIGHT_ACTION = [
+DEFAULT_NOTIF_ACTION: List[Union[dict, str]] = [
+    "notify",
+    {"set_tweak": "highlight", "value": False},
+]
+DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
     "notify",
     {"set_tweak": "sound", "value": "default"},
     {"set_tweak": "highlight"},
 ]
 
 
-class BasePushAction(TypedDict):
-    event_id: str
-    actions: List[Union[dict, str]]
-
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class HttpPushAction:
+    """
+    HttpPushAction instances include the information used to generate HTTP
+    requests to a push gateway.
+    """
 
-class HttpPushAction(BasePushAction):
+    event_id: str
     room_id: str
     stream_ordering: int
+    actions: List[Union[dict, str]]
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class EmailPushAction(HttpPushAction):
+    """
+    EmailPushAction instances include the information used to render an email
+    push notification.
+    """
+
     received_ts: Optional[int]
 
 
-def _serialize_action(actions, is_highlight):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPushAction(EmailPushAction):
+    """
+    UserPushAction instances include the necessary information to respond to
+    /notifications requests.
+    """
+
+    topological_ordering: int
+    highlight: bool
+    profile_tag: str
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class NotifCounts:
+    """
+    The per-user, per-room count of notifications. Used by sync and push.
+    """
+
+    notify_count: int
+    unread_count: int
+    highlight_count: int
+
+
+def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
     """Custom serializer for actions. This allows us to "compress" common actions.
 
     We use the fact that most users have the same actions for notifs (and for
@@ -70,7 +108,7 @@ def _serialize_action(actions, is_highlight):
     return json_encoder.encode(actions)
 
 
-def _deserialize_action(actions, is_highlight):
+def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]:
     """Custom deserializer for actions. This allows us to "compress" common actions"""
     if actions:
         return db_to_json(actions)
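
The two helpers above trade a little CPU for storage: the overwhelmingly common default action list is collapsed to a sentinel and expanded again on read. A rough sketch of that round trip; the empty-string sentinel is an assumption inferred from the `if actions:` fallback, since the full serializer body is not shown in this hunk:

import json

DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]

def serialize(actions, is_highlight: bool) -> str:
    # Assumed compression: store an empty string for the common default,
    # a full JSON blob for anything else.
    if not is_highlight and actions == DEFAULT_NOTIF_ACTION:
        return ""
    return json.dumps(actions)

def deserialize(serialized: str, is_highlight: bool):
    if serialized:
        return json.loads(serialized)
    return DEFAULT_NOTIF_ACTION

assert deserialize(serialize(DEFAULT_NOTIF_ACTION, False), False) == DEFAULT_NOTIF_ACTION
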
@@ -82,12 +120,17 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
-        self.stream_ordering_month_ago = None
-        self.stream_ordering_day_ago = None
+        self.stream_ordering_month_ago: Optional[int] = None
+        self.stream_ordering_day_ago: Optional[int] = None
 
         cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
         self._find_stream_orderings_for_times_txn(cur)
@@ -111,7 +154,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         room_id: str,
         user_id: str,
         last_read_event_id: Optional[str],
-    ) -> Dict[str, int]:
+    ) -> NotifCounts:
         """Get the notification count, the highlight count and the unread message count
         for a given user in a given room after the given read receipt.
 
@@ -140,15 +183,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     def _get_unread_counts_by_receipt_txn(
         self,
-        txn,
-        room_id,
-        user_id,
-        last_read_event_id,
-    ):
+        txn: LoggingTransaction,
+        room_id: str,
+        user_id: str,
+        last_read_event_id: Optional[str],
+    ) -> NotifCounts:
         stream_ordering = None
 
         if last_read_event_id is not None:
-            stream_ordering = self.get_stream_id_for_event_txn(
+            stream_ordering = self.get_stream_id_for_event_txn(  # type: ignore[attr-defined]
                 txn,
                 last_read_event_id,
                 allow_none=True,
@@ -166,13 +209,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 retcol="event_id",
             )
 
-            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)  # type: ignore[attr-defined]
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
         )
 
-    def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
+    def _get_unread_counts_by_pos_txn(
+        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+    ) -> NotifCounts:
         sql = (
             "SELECT"
             "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
@@ -210,16 +255,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 # for this row.
                 unread_count += row[1]
 
-        return {
-            "notify_count": notif_count,
-            "unread_count": unread_count,
-            "highlight_count": highlight_count,
-        }
+        return NotifCounts(
+            notify_count=notif_count,
+            unread_count=unread_count,
+            highlight_count=highlight_count,
+        )
 
     async def get_push_action_users_in_range(
-        self, min_stream_ordering, max_stream_ordering
-    ):
-        def f(txn):
+        self, min_stream_ordering: int, max_stream_ordering: int
+    ) -> List[str]:
+        def f(txn: LoggingTransaction) -> List[str]:
             sql = (
                 "SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
                 " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
@@ -227,8 +272,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
             return [r[0] for r in txn]
 
-        ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f)
-        return ret
+        return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
 
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
@@ -254,7 +298,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """
         # find rooms that have a read receipt in them and return the next
         # push actions
-        def get_after_receipt(txn):
+        def get_after_receipt(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, str, bool]]:
             # find rooms that have a read receipt in them and return the next
             # push actions
             sql = (
@@ -280,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
         after_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
@@ -289,7 +335,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # There are rooms with push actions in them but you don't have a read receipt in
         # them e.g. rooms you've been invited to, so get push actions for rooms which do
         # not have read receipts in them too.
-        def get_no_receipt(txn):
+        def get_no_receipt(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, str, bool]]:
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
                 "   ep.highlight "
@@ -309,19 +357,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
         no_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
         )
 
         notifs = [
-            {
-                "event_id": row[0],
-                "room_id": row[1],
-                "stream_ordering": row[2],
-                "actions": _deserialize_action(row[3], row[4]),
-            }
+            HttpPushAction(
+                event_id=row[0],
+                room_id=row[1],
+                stream_ordering=row[2],
+                actions=_deserialize_action(row[3], row[4]),
+            )
             for row in after_read_receipt + no_read_receipt
         ]
 
@@ -329,7 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # contain results from the first query, correctly ordered, followed
         # by results from the second query, but we want them all ordered
         # by stream_ordering, oldest first.
-        notifs.sort(key=lambda r: r["stream_ordering"])
+        notifs.sort(key=lambda r: r.stream_ordering)
 
         # Take only up to the limit. We have to stop at the limit because
         # one of the subqueries may have hit the limit.
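
With the change above, HTTP push notifications are frozen attrs objects rather than ad-hoc dicts, so downstream code reads attributes instead of keys. A small hedged sketch of the consumer-side pattern; the helper below is invented, only the attribute names come from the class defined earlier in this file:

from typing import List

from synapse.storage.databases.main.event_push_actions import HttpPushAction

def newest_event_id(notifs: List[HttpPushAction]) -> str:
    # Attribute access (n.stream_ordering, n.event_id) replaces the
    # n["stream_ordering"] / n["event_id"] dict lookups used before this change.
    latest = max(notifs, key=lambda n: n.stream_ordering)
    return latest.event_id
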
@@ -359,7 +407,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """
         # find rooms that have a read receipt in them and return the most recent
         # push actions
-        def get_after_receipt(txn):
+        def get_after_receipt(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, str, bool, int]]:
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
                 "  ep.highlight, e.received_ts"
@@ -384,7 +434,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
         after_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
@@ -393,7 +443,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # There are rooms with push actions in them but you don't have a read receipt in
         # them e.g. rooms you've been invited to, so get push actions for rooms which do
         # not have read receipts in them too.
-        def get_no_receipt(txn):
+        def get_no_receipt(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, str, bool, int]]:
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
                 "   ep.highlight, e.received_ts"
@@ -413,7 +465,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
         no_read_receipt = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
@@ -421,13 +473,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         # Make a list of dicts from the two sets of results.
         notifs = [
-            {
-                "event_id": row[0],
-                "room_id": row[1],
-                "stream_ordering": row[2],
-                "actions": _deserialize_action(row[3], row[4]),
-                "received_ts": row[5],
-            }
+            EmailPushAction(
+                event_id=row[0],
+                room_id=row[1],
+                stream_ordering=row[2],
+                actions=_deserialize_action(row[3], row[4]),
+                received_ts=row[5],
+            )
             for row in after_read_receipt + no_read_receipt
         ]
 
@@ -435,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # contain results from the first query, correctly ordered, followed
         # by results from the second query, but we want them all ordered
         # by received_ts (most recent first)
-        notifs.sort(key=lambda r: -(r["received_ts"] or 0))
+        notifs.sort(key=lambda r: -(r.received_ts or 0))
 
         # Now return the first `limit`
         return notifs[:limit]
@@ -456,7 +508,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             True if there may be push to process, False if there definitely isn't.
         """
 
-        def _get_if_maybe_push_in_range_for_user_txn(txn):
+        def _get_if_maybe_push_in_range_for_user_txn(txn: LoggingTransaction) -> bool:
             sql = """
                 SELECT 1 FROM event_push_actions
                 WHERE user_id = ? AND stream_ordering > ? AND notif = 1
@@ -490,19 +542,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         # This is a helper function for generating the necessary tuple that
         # can be used to insert into the `event_push_actions_staging` table.
-        def _gen_entry(user_id, actions):
+        def _gen_entry(
+            user_id: str, actions: List[Union[dict, str]]
+        ) -> Tuple[str, str, str, int, int, int]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
             return (
                 event_id,  # event_id column
                 user_id,  # user_id column
-                _serialize_action(actions, is_highlight),  # actions column
+                _serialize_action(actions, bool(is_highlight)),  # actions column
                 notif,  # notif column
                 is_highlight,  # highlight column
                 int(count_as_unread),  # unread column
             )
 
-        def _add_push_actions_to_staging_txn(txn):
+        def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
             # We don't use simple_insert_many here to avoid the overhead
             # of generating lists of dicts.
 
@@ -530,12 +584,11 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """
 
         try:
-            res = await self.db_pool.simple_delete(
+            await self.db_pool.simple_delete(
                 table="event_push_actions_staging",
                 keyvalues={"event_id": event_id},
                 desc="remove_push_actions_from_staging",
             )
-            return res
         except Exception:
             # this method is called from an exception handler, so propagating
             # another exception here really isn't helpful - there's nothing
@@ -588,7 +641,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     @staticmethod
-    def _find_first_stream_ordering_after_ts_txn(txn, ts):
+    def _find_first_stream_ordering_after_ts_txn(
+        txn: LoggingTransaction, ts: int
+    ) -> int:
         """
         Find the stream_ordering of the first event that was received on or
         after a given timestamp. This is relatively slow as there is no index
@@ -600,14 +655,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         stream_ordering
 
         Args:
-            txn (twisted.enterprise.adbapi.Transaction):
-            ts (int): timestamp to search for
+            txn:
+            ts: timestamp to search for
 
         Returns:
-            int: stream ordering
+            The stream ordering
         """
         txn.execute("SELECT MAX(stream_ordering) FROM events")
-        max_stream_ordering = txn.fetchone()[0]
+        max_stream_ordering = cast(Tuple[Optional[int]], txn.fetchone())[0]
 
         if max_stream_ordering is None:
             return 0
@@ -663,8 +718,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         return range_end
 
-    async def get_time_of_last_push_action_before(self, stream_ordering):
-        def f(txn):
+    async def get_time_of_last_push_action_before(
+        self, stream_ordering: int
+    ) -> Optional[int]:
+        def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
             sql = (
                 "SELECT e.received_ts"
                 " FROM event_push_actions AS ep"
@@ -674,7 +731,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 " LIMIT 1"
             )
             txn.execute(sql, (stream_ordering,))
-            return txn.fetchone()
+            return cast(Optional[Tuple[int]], txn.fetchone())
 
         result = await self.db_pool.runInteraction(
             "get_time_of_last_push_action_before", f
@@ -682,7 +739,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         return result[0] if result else None
 
     @wrap_as_background_process("rotate_notifs")
-    async def _rotate_notifs(self):
+    async def _rotate_notifs(self) -> None:
         if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
             return
         self._doing_notif_rotation = True
@@ -700,7 +757,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         finally:
             self._doing_notif_rotation = False
 
-    def _rotate_notifs_txn(self, txn):
+    def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
         """Archives older notifications into event_push_summary. Returns whether
         the archiving process has caught up or not.
         """
@@ -725,6 +782,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         stream_row = txn.fetchone()
         if stream_row:
             (offset_stream_ordering,) = stream_row
+            assert self.stream_ordering_day_ago is not None
             rotate_to_stream_ordering = min(
                 self.stream_ordering_day_ago, offset_stream_ordering
             )
@@ -740,7 +798,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # We have caught up iff we were limited by `stream_ordering_day_ago`
         return caught_up
 
-    def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
+    def _rotate_notifs_before_txn(
+        self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+    ) -> None:
         old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
@@ -861,8 +921,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
+        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+    ) -> None:
         """
         Purges old push actions for a user and room before a given
         stream_ordering.
@@ -910,7 +970,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -929,9 +994,15 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         )
 
     async def get_push_actions_for_user(
-        self, user_id, before=None, limit=50, only_highlight=False
-    ):
-        def f(txn):
+        self,
+        user_id: str,
+        before: Optional[str] = None,
+        limit: int = 50,
+        only_highlight: bool = False,
+    ) -> List[UserPushAction]:
+        def f(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, int, int, str, bool, str, int]]:
             before_clause = ""
             if before:
                 before_clause = "AND epa.stream_ordering < ?"
@@ -958,32 +1029,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 " LIMIT ?" % (before_clause,)
             )
             txn.execute(sql, args)
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(
+                List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall()
+            )
 
         push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
-        for pa in push_actions:
-            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
-        return push_actions
+        return [
+            UserPushAction(
+                event_id=row[0],
+                room_id=row[1],
+                stream_ordering=row[2],
+                actions=_deserialize_action(row[4], row[5]),
+                received_ts=row[7],
+                topological_ordering=row[3],
+                highlight=row[5],
+                profile_tag=row[6],
+            )
+            for row in push_actions
+        ]
 
 
-def _action_has_highlight(actions):
+def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
     for action in actions:
-        try:
-            if action.get("set_tweak", None) == "highlight":
-                return action.get("value", True)
-        except AttributeError:
-            pass
+        if not isinstance(action, dict):
+            continue
+
+        if action.get("set_tweak", None) == "highlight":
+            return action.get("value", True)
 
     return False
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _EventPushSummary:
     """Summary of pending event push actions for a given user in a given room.
     Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
     """
 
-    unread_count = attr.ib(type=int)
-    stream_ordering = attr.ib(type=int)
-    old_user_id = attr.ib(type=str)
-    notif_count = attr.ib(type=int)
+    unread_count: int
+    stream_ordering: int
+    old_user_id: str
+    notif_count: int
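
The isinstance check in the rewritten `_action_has_highlight` above replaces the old try/except AttributeError: action lists mix bare strings with tweak dicts, and only the dicts can carry a highlight tweak. A short self-contained check of the behaviour, using action shapes matching the defaults at the top of this file:

from typing import List, Union

def action_has_highlight(actions: List[Union[dict, str]]) -> bool:
    for action in actions:
        # Bare strings such as "notify" are skipped; only tweak dicts are inspected.
        if not isinstance(action, dict):
            continue
        if action.get("set_tweak", None) == "highlight":
            return action.get("value", True)
    return False

assert action_has_highlight(["notify", {"set_tweak": "highlight"}]) is True
assert action_has_highlight(["notify", {"set_tweak": "highlight", "value": False}]) is False
assert action_has_highlight(["notify", {"set_tweak": "sound", "value": "default"}]) is False
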
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..dd255aefb9 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@ from collections import OrderedDict
 from typing import (
     TYPE_CHECKING,
     Any,
+    Collection,
     Dict,
     Generator,
     Iterable,
@@ -40,10 +41,13 @@ from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
@@ -94,7 +98,7 @@ class PersistEventsStore:
         hs: "HomeServer",
         db: DatabasePool,
         main_data_store: "DataStore",
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
     ):
         self.hs = hs
         self.db_pool = db
@@ -1319,14 +1323,13 @@ class PersistEventsStore:
 
         return [ec for ec in events_and_contexts if ec[0] not in to_remove]
 
-    def _store_event_txn(self, txn, events_and_contexts):
+    def _store_event_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Insert new events into the event, event_json, redaction and
         state_events tables.
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
         """
 
         if not events_and_contexts:
@@ -1339,46 +1342,58 @@ class PersistEventsStore:
             d.pop("redacted_because", None)
             return d
 
-        self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_values_txn(
             txn,
             table="event_json",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "internal_metadata": json_encoder.encode(
-                        event.internal_metadata.get_dict()
-                    ),
-                    "json": json_encoder.encode(event_dict(event)),
-                    "format_version": event.format_version,
-                }
+            keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
+            values=(
+                (
+                    event.event_id,
+                    event.room_id,
+                    json_encoder.encode(event.internal_metadata.get_dict()),
+                    json_encoder.encode(event_dict(event)),
+                    event.format_version,
+                )
                 for event, _ in events_and_contexts
-            ],
+            ),
         )
 
-        self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_values_txn(
             txn,
             table="events",
-            values=[
-                {
-                    "instance_name": self._instance_name,
-                    "stream_ordering": event.internal_metadata.stream_ordering,
-                    "topological_ordering": event.depth,
-                    "depth": event.depth,
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "type": event.type,
-                    "processed": True,
-                    "outlier": event.internal_metadata.is_outlier(),
-                    "origin_server_ts": int(event.origin_server_ts),
-                    "received_ts": self._clock.time_msec(),
-                    "sender": event.sender,
-                    "contains_url": (
-                        "url" in event.content and isinstance(event.content["url"], str)
-                    ),
-                }
+            keys=(
+                "instance_name",
+                "stream_ordering",
+                "topological_ordering",
+                "depth",
+                "event_id",
+                "room_id",
+                "type",
+                "processed",
+                "outlier",
+                "origin_server_ts",
+                "received_ts",
+                "sender",
+                "contains_url",
+            ),
+            values=(
+                (
+                    self._instance_name,
+                    event.internal_metadata.stream_ordering,
+                    event.depth,  # topological_ordering
+                    event.depth,  # depth
+                    event.event_id,
+                    event.room_id,
+                    event.type,
+                    True,  # processed
+                    event.internal_metadata.is_outlier(),
+                    int(event.origin_server_ts),
+                    self._clock.time_msec(),
+                    event.sender,
+                    "url" in event.content and isinstance(event.content["url"], str),
+                )
                 for event, _ in events_and_contexts
-            ],
+            ),
         )
 
         # If we're persisting an unredacted event we go and ensure
@@ -1397,27 +1412,15 @@ class PersistEventsStore:
         )
         txn.execute(sql + clause, [False] + args)
 
-        state_events_and_contexts = [
-            ec for ec in events_and_contexts if ec[0].is_state()
-        ]
-
-        state_values = []
-        for event, _ in state_events_and_contexts:
-            vals = {
-                "event_id": event.event_id,
-                "room_id": event.room_id,
-                "type": event.type,
-                "state_key": event.state_key,
-            }
-
-            # TODO: How does this work with backfilling?
-            if hasattr(event, "replaces_state"):
-                vals["prev_state"] = event.replaces_state
-
-            state_values.append(vals)
-
-        self.db_pool.simple_insert_many_txn(
-            txn, table="state_events", values=state_values
+        self.db_pool.simple_insert_many_values_txn(
+            txn,
+            table="state_events",
+            keys=("event_id", "room_id", "type", "state_key"),
+            values=(
+                (event.event_id, event.room_id, event.type, event.state_key)
+                for event, _ in events_and_contexts
+                if event.is_state()
+            ),
         )
 
     def _store_rejected_events_txn(self, txn, events_and_contexts):
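
For reference, the switch above changes the shape of the data fed to the insert helper: instead of one dict per row, column names are supplied once and rows become plain tuples (or a generator of tuples). A toy illustration of the two shapes, with invented rows:

# Old style (simple_insert_many_txn): column names repeated in every row dict.
rows_as_dicts = [
    {"event_id": "$a", "room_id": "!r:example.org", "type": "m.room.member", "state_key": "@u:example.org"},
    {"event_id": "$b", "room_id": "!r:example.org", "type": "m.room.name", "state_key": ""},
]

# New style (simple_insert_many_values_txn): keys once, values as tuples.
keys = ("event_id", "room_id", "type", "state_key")
values = (
    (row["event_id"], row["room_id"], row["type"], row["state_key"])
    for row in rows_as_dicts
)
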
@@ -1780,10 +1783,14 @@ class PersistEventsStore:
         )
 
         if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+            txn.call_after(
+                self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
+            )
 
         if rel_type == RelationTypes.THREAD:
-            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+            txn.call_after(
+                self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
+            )
 
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
         """Handles keeping track of insertion events and edges/connections.
@@ -1969,14 +1976,17 @@ class PersistEventsStore:
                 txn, self.store.get_retention_policy_for_room, (event.room_id,)
             )
 
-    def store_event_search_txn(self, txn, event, key, value):
+    def store_event_search_txn(
+        self, txn: LoggingTransaction, event: EventBase, key: str, value: str
+    ) -> None:
         """Add event to the search table
 
         Args:
-            txn (cursor):
-            event (EventBase):
-            key (str):
-            value (str):
+            txn: The database transaction.
+            event: The event being added to the search table.
+            key: A key describing the search value (one of "content.name",
+                "content.topic", or "content.body")
+            value: The value from the event's content.
         """
         self.store.store_search_entries_txn(
             txn,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c88fd35e7f..a68f14ba48 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
 
 import attr
 
@@ -23,6 +23,7 @@ from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
@@ -83,7 +84,12 @@ class _CalculateChainCover:
 
 
 class EventsBackgroundUpdatesStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
@@ -234,12 +240,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         ################################################################################
 
-    async def _background_reindex_fields_sender(self, progress, batch_size):
+    async def _background_reindex_fields_sender(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        def reindex_txn(txn):
+        def reindex_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
                 " INNER JOIN event_json USING (event_id)"
@@ -301,12 +309,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return result
 
-    async def _background_reindex_origin_server_ts(self, progress, batch_size):
+    async def _background_reindex_origin_server_ts(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
                 " WHERE ? <= stream_ordering AND stream_ordering < ?"
@@ -375,7 +385,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return result
 
-    async def _cleanup_extremities_bg_update(self, progress, batch_size):
+    async def _cleanup_extremities_bg_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to clean out extremities that should have been
         deleted previously.
 
@@ -396,12 +408,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         # have any descendants, but if they do then we should delete those
         # extremities.
 
-        def _cleanup_extremities_bg_update_txn(txn):
+        def _cleanup_extremities_bg_update_txn(txn: LoggingTransaction) -> int:
             # The set of extremity event IDs that we're checking this round
             original_set = set()
 
-            # A dict[str, set[str]] of event ID to their prev events.
-            graph = {}
+            # A dict[str, Set[str]] mapping each event ID to its prev events.
+            graph: Dict[str, Set[str]] = {}
 
             # The set of descendants of the original set that are not rejected
             # nor soft-failed. Ancestors of these events should be removed
@@ -530,7 +542,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 room_ids = {row["room_id"] for row in rows}
                 for room_id in room_ids:
                     txn.call_after(
-                        self.get_latest_event_ids_in_room.invalidate, (room_id,)
+                        self.get_latest_event_ids_in_room.invalidate, (room_id,)  # type: ignore[attr-defined]
                     )
 
             self.db_pool.simple_delete_many_txn(
@@ -552,7 +564,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
             )
 
-            def _drop_table_txn(txn):
+            def _drop_table_txn(txn: LoggingTransaction) -> None:
                 txn.execute("DROP TABLE _extremities_to_check")
 
             await self.db_pool.runInteraction(
@@ -561,11 +573,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return num_handled
 
-    async def _redactions_received_ts(self, progress, batch_size):
+    async def _redactions_received_ts(self, progress: JsonDict, batch_size: int) -> int:
         """Handles filling out the `received_ts` column in redactions."""
         last_event_id = progress.get("last_event_id", "")
 
-        def _redactions_received_ts_txn(txn):
+        def _redactions_received_ts_txn(txn: LoggingTransaction) -> int:
             # Fetch the set of event IDs that we want to update
             sql = """
                 SELECT event_id FROM redactions
@@ -616,10 +628,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return count
 
-    async def _event_fix_redactions_bytes(self, progress, batch_size):
+    async def _event_fix_redactions_bytes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Undoes hex encoded censored redacted event JSON."""
 
-        def _event_fix_redactions_bytes_txn(txn):
+        def _event_fix_redactions_bytes_txn(txn: LoggingTransaction) -> None:
             # This update is quite fast due to new index.
             txn.execute(
                 """
@@ -644,11 +658,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return 1
 
-    async def _event_store_labels(self, progress, batch_size):
+    async def _event_store_labels(self, progress: JsonDict, batch_size: int) -> int:
         """Background update handler which will store labels for existing events."""
         last_event_id = progress.get("last_event_id", "")
 
-        def _event_store_labels_txn(txn):
+        def _event_store_labels_txn(txn: LoggingTransaction) -> int:
             txn.execute(
                 """
                 SELECT event_id, json FROM event_json
@@ -748,7 +762,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 ),
             )
 
-            return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn]  # type: ignore
+            return cast(
+                List[Tuple[str, str, JsonDict, bool, bool]],
+                [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn],
+            )
 
         results = await self.db_pool.runInteraction(
             desc="_rejected_events_metadata_get", func=get_rejected_events
@@ -906,7 +923,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     def _calculate_chain_cover_txn(
         self,
-        txn: Cursor,
+        txn: LoggingTransaction,
         last_room_id: str,
         last_depth: int,
         last_stream: int,
@@ -1017,10 +1034,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         PersistEventsStore._add_chain_cover_index(
             txn,
             self.db_pool,
-            self.event_chain_id_gen,
+            self.event_chain_id_gen,  # type: ignore[attr-defined]
             event_to_room_id,
             event_to_types,
-            event_to_auth_chain,
+            cast(Dict[str, Sequence[str]], event_to_auth_chain),
         )
 
         return _CalculateChainCover(
@@ -1040,7 +1057,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         """
         current_event_id = progress.get("current_event_id", "")
 
-        def purged_chain_cover_txn(txn) -> int:
+        def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
             # The event ID from events will be null if the chain ID / sequence
             # number points to a purged event.
             sql = """
@@ -1175,14 +1192,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 # Iterate the parent IDs and invalidate caches.
                 for parent_id in {r[1] for r in relations_to_insert}:
                     cache_tuple = (parent_id,)
-                    self._invalidate_cache_and_stream(
-                        txn, self.get_relations_for_event, cache_tuple
+                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                        txn, self.get_relations_for_event, cache_tuple  # type: ignore[attr-defined]
                     )
-                    self._invalidate_cache_and_stream(
-                        txn, self.get_aggregation_groups_for_event, cache_tuple
+                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                        txn, self.get_aggregation_groups_for_event, cache_tuple  # type: ignore[attr-defined]
                     )
-                    self._invalidate_cache_and_stream(
-                        txn, self.get_thread_summary, cache_tuple
+                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
+                        txn, self.get_thread_summary, cache_tuple  # type: ignore[attr-defined]
                     )
 
             if results:
@@ -1214,7 +1231,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         """
         batch_size = max(batch_size, 1)
 
-        def process(txn: Cursor) -> int:
+        def process(txn: LoggingTransaction) -> int:
             last_stream = progress.get("last_stream", -(1 << 31))
             txn.execute(
                 """
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c7b660ac5a..8d4287045a 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1383,10 +1383,6 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_events_token(self) -> int:
-        """The current maximum token that events have reached"""
-        return self._stream_id_gen.get_current_token()
-
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index cf842803bc..cb9ee08fa8 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union
+from typing import Optional, Tuple, Union, cast
 
 from canonicaljson import encode_canonical_json
 
@@ -63,7 +63,7 @@ class FilteringStore(SQLBaseStore):
 
             sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
             txn.execute(sql, (user_localpart,))
-            max_id = txn.fetchone()[0]  # type: ignore[index]
+            max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
             if max_id is None:
                 filter_id = 0
             else:
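
DB-API cursors type fetchone() as returning an optional row, so indexing the result previously needed a blanket ignore; the cast above documents the expected single-column shape instead. A self-contained illustration of the same pattern, with sqlite3 standing in for Synapse's database layer:

    import sqlite3
    from typing import Optional, Tuple, cast

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE user_filters (filter_id INTEGER)")
    cur.execute("SELECT MAX(filter_id) FROM user_filters")
    # A MAX() query always yields exactly one row, so the cast is safe;
    # the column itself can still be NULL when the table is empty.
    max_id = cast(Tuple[Optional[int]], cur.fetchone())[0]
    print(max_id)  # None here, because no filters exist yet
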
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index bb621df0dd..3f6086050b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -19,8 +19,7 @@ from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict):
 
 
 class GroupServerWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         database.updates.register_background_index_update(
             update_name="local_group_updates_index",
             index_name="local_group_updates_stream_id_index",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index a540f7fb26..bedacaf0d7 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -20,8 +20,11 @@ from twisted.internet.interfaces import IReactorCore
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.types import Connection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
@@ -54,7 +57,12 @@ class LockStore(SQLBaseStore):
     `last_renewed_ts` column with the current time.
     """
 
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._reactor = hs.get_reactor()
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 1b076683f7..cbba356b4a 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -23,6 +23,7 @@ from typing import (
     Optional,
     Tuple,
     Union,
+    cast,
 )
 
 from synapse.storage._base import SQLBaseStore
@@ -220,7 +221,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 WHERE user_id = ?
             """
             txn.execute(sql, args)
-            count = txn.fetchone()[0]  # type: ignore[index]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = """
                 SELECT
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d901933ae4..1480a0f048 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Dict
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
@@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     stats and prometheus metrics.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Read the extrems every 60 minutes
@@ -100,7 +105,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         def _count_messages(txn):
             sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
+                SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
@@ -117,7 +122,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             like_clause = "%:" + self.hs.hostname
 
             sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
+                SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.encrypted'
                     AND sender LIKE ?
                 AND stream_ordering > ?
@@ -134,7 +139,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     async def count_daily_active_e2ee_rooms(self):
         def _count(txn):
             sql = """
-                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
@@ -156,7 +161,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
         def _count_messages(txn):
             sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
+                SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
@@ -173,7 +178,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             like_clause = "%:" + self.hs.hostname
 
             sql = """
-                SELECT COALESCE(COUNT(*), 0) FROM events
+                SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.message'
                     AND sender LIKE ?
                 AND stream_ordering > ?
@@ -190,7 +195,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     async def count_daily_active_rooms(self):
         def _count(txn):
             sql = """
-                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
@@ -226,7 +231,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         Returns number of users seen in the past time_from period
         """
         sql = """
-            SELECT COALESCE(count(*), 0) FROM (
+            SELECT COUNT(*) FROM (
                 SELECT user_id FROM user_ips
                 WHERE last_seen > ?
                 GROUP BY user_id
@@ -253,7 +258,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             thirty_days_ago_in_secs = now - thirty_days_in_secs
 
             sql = """
-                SELECT platform, COALESCE(count(*), 0) FROM (
+                SELECT platform, COUNT(*) FROM (
                      SELECT
                         users.name, platform, users.creation_ts * 1000,
                         MAX(uip.last_seen)
@@ -291,7 +296,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                 results[row[0]] = row[1]
 
             sql = """
-                SELECT COALESCE(count(*), 0) FROM (
+                SELECT COUNT(*) FROM (
                     SELECT users.name, users.creation_ts * 1000,
                                                         MAX(uip.last_seen)
                     FROM users
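
The COALESCE(COUNT(*), 0) wrappers removed throughout this file were redundant: COUNT never returns NULL, even over an empty input set, so a plain COUNT(*) already yields 0. A quick check, with sqlite3 standing in for the real database:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE events (type TEXT, stream_ordering INTEGER)")
    cur.execute("SELECT COUNT(*) FROM events WHERE type = 'm.room.encrypted'")
    print(cur.fetchone()[0])  # 0, not None - the COALESCE added nothing
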
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b5284e4f67..8f09dd8e87 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -16,8 +16,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    make_in_list_sql_clause,
+)
 from synapse.util.caches.descriptors import cached
+from synapse.util.threepids import canonicalise_email
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -30,7 +35,12 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
@@ -49,7 +59,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         def _count_users(txn):
             # Exclude app service users
             sql = """
-                SELECT COALESCE(count(*), 0)
+                SELECT COUNT(*)
                 FROM monthly_active_users
                     LEFT JOIN users
                     ON monthly_active_users.user_id=users.name
@@ -76,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
         def _count_users_by_service(txn):
             sql = """
-                SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
+                SELECT COALESCE(appservice_id, 'native'), COUNT(*)
                 FROM monthly_active_users
                 LEFT JOIN users ON monthly_active_users.user_id=users.name
                 GROUP BY appservice_id;
@@ -103,7 +113,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             : self.hs.config.server.max_mau_value
         ]:
             user_id = await self.hs.get_datastore().get_user_id_by_threepid(
-                tp["medium"], tp["address"]
+                tp["medium"], canonicalise_email(tp["address"])
             )
             if user_id:
                 users.append(user_id)
@@ -212,7 +222,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
 
 class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._mau_stats_only = hs.config.server.mau_stats_only
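
Reserved-threepid lookups now pass the configured address through canonicalise_email before querying, so the lookup uses the same canonical form that registration stores. Purely as an illustration of why both sides must agree on one form (the normaliser below is a stand-in, not Synapse's implementation):

    def canonicalise(address: str) -> str:
        # Stand-in only: trim and lower-case so equivalent spellings compare equal.
        return address.strip().lower()

    stored = canonicalise("User@Example.COM")        # form written at registration
    looked_up = canonicalise(" user@example.com ")   # form taken from the config
    assert stored == looked_up == "user@example.com"
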
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cc0eebdb46..cbf9ec38f7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
@@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -269,6 +269,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         """
         # Add user entries to the table, updating the presence_stream_id column if the user already
         # exists in the table.
+        presence_stream_id = self._presence_id_gen.get_current_token()
         await self.db_pool.simple_upsert_many(
             table="users_to_send_full_presence_to",
             key_names=("user_id",),
@@ -279,9 +280,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             # devices at different times, each device will receive full presence once - when
             # the presence stream ID in their sync token is less than the one in the table
             # for their user ID.
-            value_values=(
-                (self._presence_id_gen.get_current_token(),) for _ in user_ids
-            ),
+            value_values=[(presence_stream_id,) for _ in user_ids],
             desc="add_users_to_send_full_presence_to",
         )
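
Two things change here: the presence stream token is read once instead of once per user, and the values are built as a list rather than a generator. The list matters because a generator can only be consumed a single time, so any code path that iterates the values more than once would silently see nothing on the second pass:

    user_ids = ["@alice:example.org", "@bob:example.org"]
    stream_id = 42  # stand-in for self._presence_id_gen.get_current_token()

    gen = ((stream_id,) for _ in user_ids)
    print(list(gen))  # [(42,), (42,)]
    print(list(gen))  # [] - exhausted; a second pass sees nothing

    values = [(stream_id,) for _ in user_ids]
    print(list(values))  # [(42,), (42,)]
    print(list(values))  # [(42,), (42,)] - safe to iterate again
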
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..e01c94930a 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -81,7 +81,12 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b73ce53c91..747b4f31df 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -22,7 +22,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -196,27 +196,6 @@ class PusherWorkerStore(SQLBaseStore):
         # This only exists for the cachedList decorator
         raise NotImplementedError()
 
-    @cachedList(
-        cached_method_name="get_if_user_has_pusher",
-        list_name="user_ids",
-        num_args=1,
-    )
-    async def get_if_users_have_pushers(
-        self, user_ids: Iterable[str]
-    ) -> Dict[str, bool]:
-        rows = await self.db_pool.simple_select_many_batch(
-            table="pushers",
-            column="user_name",
-            iterable=user_ids,
-            retcols=["user_name"],
-            desc="get_if_users_have_pushers",
-        )
-
-        result = {user_id: False for user_id in user_ids}
-        result.update({r["user_name"]: True for r in rows})
-
-        return result
-
     async def update_pusher_last_stream_ordering(
         self, app_id, pushkey, user_id, last_stream_ordering
     ) -> None:
@@ -515,7 +494,7 @@ class PusherStore(PusherWorkerStore):
                 # invalidate, since the user might not have had a pusher before
                 await self.db_pool.runInteraction(
                     "add_pusher",
-                    self._invalidate_cache_and_stream,  # type: ignore
+                    self._invalidate_cache_and_stream,  # type: ignore[attr-defined]
                     self.get_if_user_has_pusher,
                     (user_id,),
                 )
@@ -524,7 +503,7 @@ class PusherStore(PusherWorkerStore):
         self, app_id: str, pushkey: str, user_id: str
     ) -> None:
         def delete_pusher_txn(txn, stream_id):
-            self._invalidate_cache_and_stream(  # type: ignore
+            self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
@@ -569,7 +548,7 @@ class PusherStore(PusherWorkerStore):
         pushers = list(await self.get_pushers_by_user_id(user_id))
 
         def delete_pushers_txn(txn, stream_ids):
-            self._invalidate_cache_and_stream(  # type: ignore
+            self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
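
The blanket `# type: ignore` comments in this file are narrowed to `# type: ignore[attr-defined]`: the attribute is supplied at runtime by another store class in the final datastore hierarchy, so only that specific complaint needs silencing while every other check on the line stays active. A small sketch of the difference, with hypothetical class names:

    class Base:
        pass

    class WithCache(Base):
        def _invalidate_cache(self) -> None:
            print("cache invalidated")

    def poke(obj: Base) -> None:
        # mypy only sees `obj` as Base, which has no _invalidate_cache; at runtime
        # another class in the hierarchy provides it. [attr-defined] silences just
        # that complaint, where a bare "# type: ignore" would hide every error.
        obj._invalidate_cache()  # type: ignore[attr-defined]

    poke(WithCache())  # prints "cache invalidated"
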
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c99f8aebdb..bf0b903af2 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,14 +14,29 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 from twisted.internet import defer
 
+from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
@@ -36,7 +51,12 @@ logger = logging.getLogger(__name__)
 
 
 class ReceiptsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self._instance_name = hs.get_instance_name()
 
         if isinstance(database.engine, PostgresEngine):
@@ -78,17 +98,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
-    def get_max_receipt_stream_id(self):
-        """Get the current max stream ID for receipts stream
-
-        Returns:
-            int
-        """
+    def get_max_receipt_stream_id(self) -> int:
+        """Get the current max stream ID for receipts stream"""
         return self._receipts_id_gen.get_current_token()
 
     @cached()
-    async def get_users_with_read_receipts_in_room(self, room_id):
-        receipts = await self.get_receipts_for_room(room_id, "m.read")
+    async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
+        receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
         return {r["user_id"] for r in receipts}
 
     @cached(num_args=2)
@@ -119,7 +135,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=2)
-    async def get_receipts_for_user(self, user_id, receipt_type):
+    async def get_receipts_for_user(
+        self, user_id: str, receipt_type: str
+    ) -> Dict[str, str]:
         rows = await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -129,8 +147,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return {row["room_id"]: row["event_id"] for row in rows}
 
-    async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
-        def f(txn):
+    async def get_receipts_for_user_with_orderings(
+        self, user_id: str, receipt_type: str
+    ) -> JsonDict:
+        def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
             sql = (
                 "SELECT rl.room_id, rl.event_id,"
                 " e.topological_ordering, e.stream_ordering"
@@ -209,10 +229,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cached(num_args=3, tree=True)
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[dict]:
+    ) -> List[JsonDict]:
         """See get_linearized_receipts_for_room"""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             if from_key:
                 sql = (
                     "SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +270,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         list_name="room_ids",
         num_args=3,
     )
-    async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def _get_linearized_receipts_for_rooms(
+        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+    ) -> Dict[str, List[JsonDict]]:
         if not room_ids:
             return {}
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             if from_key:
                 sql = """
                     SELECT * FROM receipts_linearized WHERE
@@ -323,7 +345,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             A dictionary of roomids to a list of receipts.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             if from_key:
                 sql = """
                     SELECT * FROM receipts_linearized WHERE
@@ -379,7 +401,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return defer.succeed([])
 
-        def _get_users_sent_receipts_between_txn(txn):
+        def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
             sql = """
                 SELECT DISTINCT user_id FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +441,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_receipts_txn(txn):
+        def get_all_updated_receipts_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, list]], int, bool]:
             sql = """
                 SELECT stream_id, room_id, receipt_type, user_id, event_id, data
                 FROM receipts_linearized
@@ -446,8 +470,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_users_with_receipts_in_room(
         self, room_id: str, receipt_type: str, user_id: str
-    ):
-        if receipt_type != "m.read":
+    ) -> None:
+        if receipt_type != ReceiptTypes.READ:
             return
 
         res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
@@ -461,7 +485,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+    def invalidate_caches_for_receipt(
+        self, room_id: str, receipt_type: str, user_id: str
+    ) -> None:
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
@@ -482,11 +508,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def insert_linearized_receipt_txn(
-        self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
-    ):
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_id: str,
+        data: JsonDict,
+        stream_id: int,
+    ) -> Optional[int]:
         """Inserts a read-receipt into the database if it's newer than the current RR
 
-        Returns: int|None
+        Returns:
             None if the RR is older than the current RR
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
@@ -550,7 +583,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             lock=False,
         )
 
-        if receipt_type == "m.read" and stream_ordering is not None:
+        if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
             self._remove_old_push_actions_before_txn(
                 txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
             )
@@ -580,7 +613,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         else:
             # we need to map points in graph form -> linearized form.
             # TODO: Make this better.
-            def graph_to_linear(txn):
+            def graph_to_linear(txn: LoggingTransaction) -> str:
                 clause, args = make_in_list_sql_clause(
                     self.database_engine, "event_id", event_ids
                 )
@@ -634,11 +667,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return stream_id, max_persisted_id
 
     async def insert_graph_receipt(
-        self, room_id, receipt_type, user_id, event_ids, data
-    ):
+        self,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_ids: List[str],
+        data: JsonDict,
+    ) -> None:
         assert self._can_write_to_receipts
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
@@ -649,8 +687,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     def insert_graph_receipt_txn(
-        self, txn, room_id, receipt_type, user_id, event_ids, data
-    ):
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_ids: List[str],
+        data: JsonDict,
+    ) -> None:
         assert self._can_write_to_receipts
 
         txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
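
Alongside the annotations, the literal "m.read" strings in this file are replaced with the ReceiptTypes constants from synapse.api.constants, so the receipt type is defined in one place and a misspelling fails loudly instead of silently never matching. The general idea, with a local stand-in class:

    class ReceiptTypes:
        """Local stand-in for synapse.api.constants.ReceiptTypes."""

        READ = "m.read"

    def is_read_receipt(receipt_type: str) -> bool:
        return receipt_type == ReceiptTypes.READ

    print(is_read_receipt(ReceiptTypes.READ))   # True
    print(is_read_receipt("m.read.private"))    # False
    # A misspelled ReceiptTypes.REED raises AttributeError (and a mypy error)
    # instead of comparing unequal forever.
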
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e1ddf06916..4175c82a25 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
@@ -794,7 +794,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             yesterday = int(self._clock.time()) - (60 * 60 * 24)
 
             sql = """
-                SELECT user_type, COALESCE(count(*), 0) AS count FROM (
+                SELECT user_type, COUNT(*) AS count FROM (
                     SELECT
                     CASE
                         WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
@@ -819,7 +819,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         def _count_users(txn):
             txn.execute(
                 """
-                SELECT COALESCE(COUNT(*), 0) FROM users
+                SELECT COUNT(*) FROM users
                 WHERE appservice_id IS NULL
             """
             )
@@ -856,7 +856,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         Args:
             medium: threepid medium e.g. email
-            address: threepid address e.g. me@example.com
+            address: threepid address e.g. me@example.com. This must already be
+                in canonical form.
 
         Returns:
             The user ID or None if no user id/threepid mapping exists
@@ -1356,12 +1357,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             # Override type because the return type is only optional if
             # allow_none is True, and we don't want mypy throwing errors
             # about None not being indexable.
-            res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
-                txn,
-                "registration_tokens",
-                keyvalues={"token": token},
-                retcols=["pending", "completed"],
-            )  # type: ignore
+            res = cast(
+                Dict[str, Any],
+                self.db_pool.simple_select_one_txn(
+                    txn,
+                    "registration_tokens",
+                    keyvalues={"token": token},
+                    retcols=["pending", "completed"],
+                ),
+            )
 
             # Decrement pending and increment completed
             self.db_pool.simple_update_one_txn(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07..4ff6aed253 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, cast
 
 import attr
 
@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_relations_for_event(
         self,
         event_id: str,
+        room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
         aggregation_key: Optional[str] = None,
@@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
             aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
             the form `{"event_id": "..."}`.
         """
 
-        where_clause = ["relates_to_id = ?"]
-        where_args: List[Union[str, int]] = [event_id]
+        where_clause = ["relates_to_id = ?", "room_id = ?"]
+        where_args: List[Union[str, int]] = [event_id, room_id]
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
@@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_aggregation_groups_for_event(
         self,
         event_id: str,
+        room_id: str,
         event_type: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
@@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
             event_type: Only fetch events with this event type, if given.
             limit: Only fetch the `limit` groups.
             direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
             `type`, `key` and `count` fields.
         """
 
-        where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
+        where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
+        where_args: List[Union[str, int]] = [
+            event_id,
+            room_id,
+            RelationTypes.ANNOTATION,
+        ]
 
         if event_type:
             where_clause.append("type = ?")
@@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+    async def get_applicable_edit(
+        self, event_id: str, room_id: str
+    ) -> Optional[EventBase]:
         """Get the most recent edit (if any) that has happened for the given
         event.
 
@@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: The original event ID
+            room_id: The original event's room ID
 
         Returns:
             The most recent edit, if any.
@@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
             WHERE
                 relates_to_id = ?
                 AND relation_type = ?
+                AND edit.room_id = ?
                 AND edit.type = 'm.room.message'
             ORDER by edit.origin_server_ts DESC, edit.event_id DESC
             LIMIT 1
         """
 
         def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
-            txn.execute(sql, (event_id, RelationTypes.REPLACE))
+            txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
             row = txn.fetchone()
             if row:
                 return row[0]
@@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
 
     @cached()
     async def get_thread_summary(
-        self, event_id: str
+        self, event_id: str, room_id: str
     ) -> Tuple[int, Optional[EventBase]]:
         """Get the number of threaded replies, the senders of those replies, and
         the latest reply (if any) for the given event.
 
         Args:
-            event_id: The original event ID
+            event_id: Summarize the thread related to this event ID.
+            room_id: The room the event belongs to.
 
         Returns:
             The number of items in the thread and the most recent response, if any.
@@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
                 INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
                 ORDER BY topological_ordering DESC, stream_ordering DESC
                 LIMIT 1
             """
 
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
             row = txn.fetchone()
             if row is None:
                 return 0, None
@@ -376,14 +390,16 @@ class RelationsWorkerStore(SQLBaseStore):
             latest_event_id = row[0]
 
             sql = """
-                SELECT COALESCE(COUNT(event_id), 0)
+                SELECT COUNT(event_id)
                 FROM event_relations
+                INNER JOIN events USING (event_id)
                 WHERE
                     relates_to_id = ?
+                    AND room_id = ?
                     AND relation_type = ?
             """
-            txn.execute(sql, (event_id, RelationTypes.THREAD))
-            count = txn.fetchone()[0]  # type: ignore[index]
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             return count, latest_event_id
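
Every relations query in this file gains a room_id filter (and callers now pass the room alongside the event ID), so a relation is only returned when it lives in the same room as the event it targets. A sketch of the where-clause construction pattern these methods share:

    from typing import List, Optional, Tuple, Union

    def build_relations_filter(
        event_id: str, room_id: str, relation_type: Optional[str] = None
    ) -> Tuple[str, List[Union[str, int]]]:
        where_clause = ["relates_to_id = ?", "room_id = ?"]
        where_args: List[Union[str, int]] = [event_id, room_id]
        if relation_type is not None:
            where_clause.append("relation_type = ?")
            where_args.append(relation_type)
        return " AND ".join(where_clause), where_args

    print(build_relations_filter("$parent", "!room:example.org", "m.annotation"))
    # ('relates_to_id = ? AND room_id = ? AND relation_type = ?',
    #  ['$parent', '!room:example.org', 'm.annotation'])
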
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d694d852d..c0e837854a 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -13,20 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
 import logging
 from abc import abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
+
+import attr
 
 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.databases.main.search import SearchStore
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import IdGenerator
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -38,9 +54,10 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-RatelimitOverride = collections.namedtuple(
-    "RatelimitOverride", ("messages_per_second", "burst_count")
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RatelimitOverride:
+    messages_per_second: int
+    burst_count: int
 
 
 class RoomSortOrder(Enum):
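
RatelimitOverride moves from a collections.namedtuple to a frozen attrs class, keeping the immutable value-object behaviour while giving each field a name and a type that mypy can check. A standalone sketch of the same construct:

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class RatelimitOverride:
        messages_per_second: int
        burst_count: int

    override = RatelimitOverride(messages_per_second=10, burst_count=50)
    print(override.messages_per_second, override.burst_count)  # 10 50
    # Assigning override.burst_count = 100 would raise
    # attr.exceptions.FrozenInstanceError, just as the namedtuple would reject it.
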
@@ -71,8 +88,13 @@ class RoomSortOrder(Enum):
     STATE_EVENTS = "state_events"
 
 
-class RoomWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+class RoomWorkerStore(CacheInvalidationWorkerStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -83,7 +105,7 @@ class RoomWorkerStore(SQLBaseStore):
         room_creator_user_id: str,
         is_public: bool,
         room_version: RoomVersion,
-    ):
+    ) -> None:
         """Stores a room.
 
         Args:
@@ -111,7 +133,7 @@ class RoomWorkerStore(SQLBaseStore):
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
 
-    async def get_room(self, room_id: str) -> dict:
+    async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
         """Retrieve a room.
 
         Args:
@@ -136,7 +158,9 @@ class RoomWorkerStore(SQLBaseStore):
             A dict containing the room information, or None if the room is unknown.
         """
 
-        def get_room_with_stats_txn(txn, room_id):
+        def get_room_with_stats_txn(
+            txn: LoggingTransaction, room_id: str
+        ) -> Optional[Dict[str, Any]]:
             sql = """
                 SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
                   curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@@ -185,7 +209,7 @@ class RoomWorkerStore(SQLBaseStore):
             ignore_non_federatable: If true filters out non-federatable rooms
         """
 
-        def _count_public_rooms_txn(txn):
+        def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
             query_args = []
 
             if network_tuple:
@@ -195,6 +219,7 @@ class RoomWorkerStore(SQLBaseStore):
                         WHERE appservice_id = ? AND network_id = ?
                     """
                     query_args.append(network_tuple.appservice_id)
+                    assert network_tuple.network_id is not None
                     query_args.append(network_tuple.network_id)
                 else:
                     published_sql = """
@@ -208,7 +233,7 @@ class RoomWorkerStore(SQLBaseStore):
 
             sql = """
                 SELECT
-                    COALESCE(COUNT(*), 0)
+                    COUNT(*)
                 FROM (
                     %(published_sql)s
                 ) published
@@ -226,7 +251,7 @@ class RoomWorkerStore(SQLBaseStore):
             }
 
             txn.execute(sql, query_args)
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         return await self.db_pool.runInteraction(
             "count_public_rooms", _count_public_rooms_txn
@@ -235,11 +260,11 @@ class RoomWorkerStore(SQLBaseStore):
     async def get_room_count(self) -> int:
         """Retrieve the total number of rooms."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             sql = "SELECT count(*)  FROM rooms"
             txn.execute(sql)
-            row = txn.fetchone()
-            return row[0] or 0
+            row = cast(Tuple[int], txn.fetchone())
+            return row[0]
 
         return await self.db_pool.runInteraction("get_rooms", f)
 
@@ -251,7 +276,7 @@ class RoomWorkerStore(SQLBaseStore):
         bounds: Optional[Tuple[int, str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
-    ):
+    ) -> List[Dict[str, Any]]:
         """Gets the largest public rooms (where largest is in terms of joined
         members, as tracked in the statistics table).
 
@@ -272,7 +297,7 @@ class RoomWorkerStore(SQLBaseStore):
         """
 
         where_clauses = []
-        query_args = []
+        query_args: List[Union[str, int]] = []
 
         if network_tuple:
             if network_tuple.appservice_id:
@@ -281,6 +306,7 @@ class RoomWorkerStore(SQLBaseStore):
                     WHERE appservice_id = ? AND network_id = ?
                 """
                 query_args.append(network_tuple.appservice_id)
+                assert network_tuple.network_id is not None
                 query_args.append(network_tuple.network_id)
             else:
                 published_sql = """
@@ -372,7 +398,9 @@ class RoomWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
 
-        def _get_largest_public_rooms_txn(txn):
+        def _get_largest_public_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Any]]:
             txn.execute(sql, query_args)
 
             results = self.db_pool.cursor_to_dict(txn)
@@ -435,7 +463,7 @@ class RoomWorkerStore(SQLBaseStore):
         """
         # Filter room names by a string
         where_statement = ""
-        search_pattern = []
+        search_pattern: List[object] = []
         if search_term:
             where_statement = """
                 WHERE LOWER(state.name) LIKE ?
@@ -543,7 +571,9 @@ class RoomWorkerStore(SQLBaseStore):
             where_statement,
         )
 
-        def _get_rooms_paginate_txn(txn):
+        def _get_rooms_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
             # Add the search term into the WHERE clause
             # and execute the data query
             txn.execute(info_sql, search_pattern + [limit, start])
@@ -575,7 +605,7 @@ class RoomWorkerStore(SQLBaseStore):
             # Add the search term into the WHERE clause if present
             txn.execute(count_sql, search_pattern)
 
-            room_count = txn.fetchone()
+            room_count = cast(Tuple[int], txn.fetchone())
             return rooms, room_count[0]
 
         return await self.db_pool.runInteraction(
@@ -620,7 +650,7 @@ class RoomWorkerStore(SQLBaseStore):
             burst_count: How many actions that can be performed before being limited.
         """
 
-        def set_ratelimit_txn(txn):
+        def set_ratelimit_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="ratelimit_override",
@@ -643,7 +673,7 @@ class RoomWorkerStore(SQLBaseStore):
             user_id: user ID of the user
         """
 
-        def delete_ratelimit_txn(txn):
+        def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="ratelimit_override",
@@ -667,7 +697,7 @@ class RoomWorkerStore(SQLBaseStore):
         await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
 
     @cached()
-    async def get_retention_policy_for_room(self, room_id):
+    async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
         """Get the retention policy for a given room.
 
         If no retention policy has been found for this room, returns a policy defined
@@ -676,13 +706,15 @@ class RoomWorkerStore(SQLBaseStore):
         configuration).
 
         Args:
-            room_id (str): The ID of the room to get the retention policy of.
+            room_id: The ID of the room to get the retention policy of.
 
         Returns:
-            dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+            A dict containing "min_lifetime" and "max_lifetime" for this room.
         """
 
-        def get_retention_policy_for_room_txn(txn):
+        def get_retention_policy_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Optional[int]]]:
             txn.execute(
                 """
                 SELECT min_lifetime, max_lifetime FROM room_retention
@@ -707,19 +739,23 @@ class RoomWorkerStore(SQLBaseStore):
                 "max_lifetime": self.config.retention.retention_default_max_lifetime,
             }
 
-        row = ret[0]
+        min_lifetime = ret[0]["min_lifetime"]
+        max_lifetime = ret[0]["max_lifetime"]
 
         # If one of the room's policy's attributes isn't defined, use the matching
         # attribute from the default policy.
         # The default values will be None if no default policy has been defined, or if one
         # of the attributes is missing from the default policy.
-        if row["min_lifetime"] is None:
-            row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
+        if min_lifetime is None:
+            min_lifetime = self.config.retention.retention_default_min_lifetime
 
-        if row["max_lifetime"] is None:
-            row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
+        if max_lifetime is None:
+            max_lifetime = self.config.retention.retention_default_max_lifetime
 
-        return row
+        return {
+            "min_lifetime": min_lifetime,
+            "max_lifetime": max_lifetime,
+        }
 
     async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
         """Retrieves all the local and remote media MXC URIs in a given room
@@ -731,7 +767,9 @@ class RoomWorkerStore(SQLBaseStore):
             The local and remote media as lists of the media IDs.
         """
 
-        def _get_media_mxcs_in_room_txn(txn):
+        def _get_media_mxcs_in_room_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], List[str]]:
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             local_media_mxcs = []
             remote_media_mxcs = []
@@ -757,7 +795,7 @@ class RoomWorkerStore(SQLBaseStore):
 
         logger.info("Quarantining media in room: %s", room_id)
 
-        def _quarantine_media_in_room_txn(txn):
+        def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             return self._quarantine_media_txn(
                 txn, local_mxcs, remote_mxcs, quarantined_by
@@ -767,13 +805,11 @@ class RoomWorkerStore(SQLBaseStore):
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
-    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+    def _get_media_mxcs_in_room_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> Tuple[List[str], List[Tuple[str, str]]]:
         """Retrieves all the local and remote media MXC URIs in a given room
 
-        Args:
-            txn (cursor)
-            room_id (str)
-
         Returns:
             The local and remote media as lists of tuples where the key is
             the hostname and the value is the media ID.
@@ -841,7 +877,7 @@ class RoomWorkerStore(SQLBaseStore):
         logger.info("Quarantining media: %s/%s", server_name, media_id)
         is_local = server_name == self.config.server.server_name
 
-        def _quarantine_media_by_id_txn(txn):
+        def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
             local_mxcs = [media_id] if is_local else []
             remote_mxcs = [(server_name, media_id)] if not is_local else []
 
@@ -863,7 +899,7 @@ class RoomWorkerStore(SQLBaseStore):
             quarantined_by: The ID of the user who made the quarantine request
         """
 
-        def _quarantine_media_by_user_txn(txn):
+        def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
             local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
             return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
 
@@ -871,7 +907,9 @@ class RoomWorkerStore(SQLBaseStore):
             "quarantine_media_by_user", _quarantine_media_by_user_txn
         )
 
-    def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+    def _get_media_ids_by_user_txn(
+        self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
+    ) -> List[str]:
         """Retrieves local media IDs by a given user
 
         Args:
@@ -900,7 +938,7 @@ class RoomWorkerStore(SQLBaseStore):
 
     def _quarantine_media_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         local_mxcs: List[str],
         remote_mxcs: List[Tuple[str, str]],
         quarantined_by: Optional[str],
@@ -928,12 +966,15 @@ class RoomWorkerStore(SQLBaseStore):
         # set quarantine
         if quarantined_by is not None:
             sql += "AND safe_from_quarantine = ?"
-            rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
+            txn.executemany(
+                sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
+            )
         # remove from quarantine
         else:
-            rows = [(quarantined_by, media_id) for media_id in local_mxcs]
+            txn.executemany(
+                sql, [(quarantined_by, media_id) for media_id in local_mxcs]
+            )
 
-        txn.executemany(sql, rows)
         # Note that a rowcount of -1 can be used to indicate no rows were affected.
         total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
 
@@ -951,7 +992,7 @@ class RoomWorkerStore(SQLBaseStore):
 
     async def get_rooms_for_retention_period_in_range(
         self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
-    ) -> Dict[str, dict]:
+    ) -> Dict[str, Dict[str, Optional[int]]]:
         """Retrieves all of the rooms within the given retention range.
 
         Optionally includes the rooms which don't have a retention policy.
@@ -971,7 +1012,9 @@ class RoomWorkerStore(SQLBaseStore):
             "min_lifetime" (int|None), and "max_lifetime" (int|None).
         """
 
-        def get_rooms_for_retention_period_in_range_txn(txn):
+        def get_rooms_for_retention_period_in_range_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, Optional[int]]]:
             range_conditions = []
             args = []
 
@@ -1050,11 +1093,14 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
 
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
-
         self.db_pool.updates.register_background_update_handler(
             "insert_room_retention",
             self._background_insert_retention,
@@ -1085,7 +1131,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             self._background_populate_rooms_creator_column,
         )
 
-    async def _background_insert_retention(self, progress, batch_size):
+    async def _background_insert_retention(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Retrieves a list of all rooms within a range and inserts an entry for each of
         them into the room_retention table.
         NULLs the property's columns if missing from the retention event in the room's
@@ -1095,7 +1143,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         last_room = progress.get("room_id", "")
 
-        def _background_insert_retention_txn(txn):
+        def _background_insert_retention_txn(txn: LoggingTransaction) -> bool:
             txn.execute(
                 """
                 SELECT state.room_id, state.event_id, events.json
@@ -1154,15 +1202,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return batch_size
 
     async def _background_add_rooms_room_version_column(
-        self, progress: dict, batch_size: int
-    ):
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to go and add room version information to `rooms`
         table from `current_state_events` table.
         """
 
         last_room_id = progress.get("room_id", "")
 
-        def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+        def _background_add_rooms_room_version_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT room_id, json FROM current_state_events
                 INNER JOIN event_json USING (room_id, event_id)
@@ -1223,7 +1273,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return batch_size
 
     async def _remove_tombstoned_rooms_from_directory(
-        self, progress, batch_size
+        self, progress: JsonDict, batch_size: int
     ) -> int:
         """Removes any rooms with tombstone events from the room directory
 
@@ -1233,7 +1283,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         last_room = progress.get("room_id", "")
 
-        def _get_rooms(txn):
+        def _get_rooms(txn: LoggingTransaction) -> List[str]:
             txn.execute(
                 """
                 SELECT room_id
@@ -1271,7 +1321,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return len(rooms)
 
     @abstractmethod
-    def set_room_is_public(self, room_id, is_public):
+    def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
         # this will need to be implemented if a background update is performed with
         # existing (tombstoned, public) rooms in the database.
         #
@@ -1318,7 +1368,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         32-bit integer field.
         """
 
-        def process(txn: Cursor) -> int:
+        def process(txn: LoggingTransaction) -> int:
             last_room = progress.get("last_room", "")
             txn.execute(
                 """
@@ -1375,15 +1425,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return 0
 
     async def _background_populate_rooms_creator_column(
-        self, progress: dict, batch_size: int
-    ):
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to go and add creator information to `rooms`
         table from `current_state_events` table.
         """
 
         last_room_id = progress.get("room_id", "")
 
-        def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
+        def _background_populate_rooms_creator_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT room_id, json FROM event_json
                 INNER JOIN rooms AS room USING (room_id)
@@ -1434,15 +1486,20 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
         return batch_size
 
 
-class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
+        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
 
     async def upsert_room_on_join(
         self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
-    ):
+    ) -> None:
         """Ensure that the room is stored in the table
 
         Called when we join a room over federation, and overwrites any room version
@@ -1488,7 +1545,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
     async def maybe_store_room_on_outlier_membership(
         self, room_id: str, room_version: RoomVersion
-    ):
+    ) -> None:
         """
         When we receive an invite or any other event over federation that may relate to a room
         we are not in, store the version of the room if we don't already know the room version.
@@ -1528,8 +1585,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         self.hs.get_notifier().on_new_replication_data()
 
     async def set_room_is_public_appservice(
-        self, room_id, appservice_id, network_id, is_public
-    ):
+        self, room_id: str, appservice_id: str, network_id: str, is_public: bool
+    ) -> None:
         """Edit the appservice/network specific public room list.
 
         Each appservice can have a number of published room lists associated
@@ -1538,11 +1595,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
         network.
 
         Args:
-            room_id (str)
-            appservice_id (str)
-            network_id (str)
-            is_public (bool): Whether to publish or unpublish the room from the
-                list.
+            room_id
+            appservice_id
+            network_id
+            is_public: Whether to publish or unpublish the room from the list.
         """
 
         if is_public:
@@ -1607,7 +1663,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             event_report: json list of information from event report
         """
 
-        def _get_event_report_txn(txn, report_id):
+        def _get_event_report_txn(
+            txn: LoggingTransaction, report_id: int
+        ) -> Optional[Dict[str, Any]]:
 
             sql = """
                 SELECT
@@ -1679,9 +1737,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             count: total number of event reports matching the filter criteria
         """
 
-        def _get_event_reports_paginate_txn(txn):
+        def _get_event_reports_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
             filters = []
-            args = []
+            args: List[object] = []
 
             if user_id:
                 filters.append("er.user_id LIKE ?")
@@ -1705,7 +1765,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 where_clause
             )
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = """
                 SELECT
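The retention change earlier in this file's diff resolves a room's policy field by field: a room with no policy row gets the configured defaults outright, while a row with a missing bound only inherits that one bound. A minimal standalone sketch of that fallback, using illustrative names (resolve_retention_policy and its parameters are not Synapse APIs):

    from typing import Dict, Optional

    def resolve_retention_policy(
        row: Optional[Dict[str, Optional[int]]],
        default_min: Optional[int],
        default_max: Optional[int],
    ) -> Dict[str, Optional[int]]:
        # No room-specific policy at all: fall back to the server defaults wholesale.
        if row is None:
            return {"min_lifetime": default_min, "max_lifetime": default_max}
        # Otherwise only the bounds missing from the room's policy fall back.
        min_lifetime = row["min_lifetime"] if row["min_lifetime"] is not None else default_min
        max_lifetime = row["max_lifetime"] if row["max_lifetime"] is not None else default_max
        return {"min_lifetime": min_lifetime, "max_lifetime": max_lifetime}
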
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..cda80d6511 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,7 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
@@ -64,7 +64,12 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -985,7 +990,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1135,7 +1145,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def forget(self, user_id: str, room_id: str) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 642560a70d..3cbaca21b5 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,13 +14,18 @@
 
 import logging
 import re
-from collections import namedtuple
 from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
 
+import attr
+
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
@@ -29,10 +34,15 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-SearchEntry = namedtuple(
-    "SearchEntry",
-    ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
-)
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SearchEntry:
+    key: str
+    value: str
+    event_id: str
+    room_id: str
+    stream_ordering: Optional[int]
+    origin_server_ts: int
 
 
 def _clean_value_for_search(value: str) -> str:
@@ -105,7 +115,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if not hs.config.server.enable_search:
@@ -358,7 +373,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def search_msgs(self, room_ids, search_term, keys):
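SearchEntry above is one of several namedtuples this change converts to frozen attrs classes, which keep the field names but add per-field type annotations that mypy can check. A small sketch of the same pattern (ExampleEntry and its values are made up, not part of Synapse):

    import attr

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class ExampleEntry:
        key: str
        value: str

    entry = ExampleEntry(key="content.body", value="hello")
    # Frozen attrs instances reject mutation, so they stay safely shareable:
    # entry.value = "other"  would raise attr.exceptions.FrozenInstanceError
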
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index fa2c3b1feb..2fb3e65192 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 import collections.abc
 import logging
-from collections import namedtuple
 from typing import TYPE_CHECKING, Iterable, Optional, Set
 
 from synapse.api.constants import EventTypes, Membership
@@ -22,7 +21,11 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
@@ -39,24 +42,16 @@ logger = logging.getLogger(__name__)
 MAX_STATE_DELTA_HOPS = 100
 
 
-class _GetStateGroupDelta(
-    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
-    """Return type of get_state_group_delta that implements __len__, which lets
-    us use the itrable flag when caching
-    """
-
-    __slots__ = []
-
-    def __len__(self):
-        return len(self.delta_ids) if self.delta_ids else 0
-
-
 # this inherits from EventsWorkerStore because it calls self.get_events
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers."""
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -182,11 +177,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             NotFoundError if the room is unknown
         """
         state_ids = await self.get_current_state_ids(room_id)
+
+        if not state_ids:
+            raise NotFoundError(f"Current state for room {room_id} is empty")
+
         create_id = state_ids.get((EventTypes.Create, ""))
 
         # If we can't find the create event, assume we've hit a dead end
         if not create_id:
-            raise NotFoundError("Unknown room %s" % (room_id,))
+            raise NotFoundError(f"No create event in current state for room {room_id}")
 
         # Retrieve the room's create event and return
         create_event = await self.get_event(create_id)
@@ -349,7 +348,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -536,5 +540,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 7f3624b128..188afec332 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -56,7 +56,9 @@ class StateDeltasStore(SQLBaseStore):
         prev_stream_id = int(prev_stream_id)
 
         # check we're not going backwards
-        assert prev_stream_id <= max_stream_id
+        assert (
+            prev_stream_id <= max_stream_id
+        ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
 
         if not self._curr_state_delta_stream_cache.has_any_entity_changed(
             prev_stream_id
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5d7b59d861..427ae1f649 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
 import logging
 from enum import Enum
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
 
 from typing_extensions import Counter
 
@@ -24,7 +24,11 @@ from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.api.errors import StoreError
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
@@ -96,7 +100,12 @@ class UserSortOrder(Enum):
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -117,7 +126,9 @@ class StatsStore(StateDeltasStore):
         self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
         self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
 
-    async def _populate_stats_process_users(self, progress, batch_size):
+    async def _populate_stats_process_users(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """
         This is a background update which regenerates statistics for users.
         """
@@ -129,7 +140,7 @@ class StatsStore(StateDeltasStore):
 
         last_user_id = progress.get("last_user_id", "")
 
-        def _get_next_batch(txn):
+        def _get_next_batch(txn: LoggingTransaction) -> List[str]:
             sql = """
                     SELECT DISTINCT name FROM users
                     WHERE name > ?
@@ -163,7 +174,9 @@ class StatsStore(StateDeltasStore):
 
         return len(users_to_work_on)
 
-    async def _populate_stats_process_rooms(self, progress, batch_size):
+    async def _populate_stats_process_rooms(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """This is a background update which regenerates statistics for rooms."""
         if not self.stats_enabled:
             await self.db_pool.updates._end_background_update(
@@ -173,7 +186,7 @@ class StatsStore(StateDeltasStore):
 
         last_room_id = progress.get("last_room_id", "")
 
-        def _get_next_batch(txn):
+        def _get_next_batch(txn: LoggingTransaction) -> List[str]:
             sql = """
                     SELECT DISTINCT room_id FROM current_state_events
                     WHERE room_id > ?
@@ -302,7 +315,7 @@ class StatsStore(StateDeltasStore):
             stream_id: Current position.
         """
 
-        def _bulk_update_stats_delta_txn(txn):
+        def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None:
             for stats_type, stats_updates in updates.items():
                 for stats_id, fields in stats_updates.items():
                     logger.debug(
@@ -334,7 +347,7 @@ class StatsStore(StateDeltasStore):
         stats_type: str,
         stats_id: str,
         fields: Dict[str, int],
-        complete_with_stream_id: Optional[int],
+        complete_with_stream_id: int,
         absolute_field_overrides: Optional[Dict[str, int]] = None,
     ) -> None:
         """
@@ -367,14 +380,14 @@ class StatsStore(StateDeltasStore):
 
     def _update_stats_delta_txn(
         self,
-        txn,
-        ts,
-        stats_type,
-        stats_id,
-        fields,
-        complete_with_stream_id,
-        absolute_field_overrides=None,
-    ):
+        txn: LoggingTransaction,
+        ts: int,
+        stats_type: str,
+        stats_id: str,
+        fields: Dict[str, int],
+        complete_with_stream_id: int,
+        absolute_field_overrides: Optional[Dict[str, int]] = None,
+    ) -> None:
         if absolute_field_overrides is None:
             absolute_field_overrides = {}
 
@@ -417,20 +430,23 @@ class StatsStore(StateDeltasStore):
         )
 
     def _upsert_with_additive_relatives_txn(
-        self, txn, table, keyvalues, absolutes, additive_relatives
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        absolutes: Dict[str, Any],
+        additive_relatives: Dict[str, int],
+    ) -> None:
         """Used to update values in the stats tables.
 
         This is basically a slightly convoluted upsert that *adds* to any
         existing rows.
 
         Args:
-            txn
-            table (str): Table name
-            keyvalues (dict[str, any]): Row-identifying key values
-            absolutes (dict[str, any]): Absolute (set) fields
-            additive_relatives (dict[str, int]): Fields that will be added onto
-                if existing row present.
+            table: Table name
+            keyvalues: Row-identifying key values
+            absolutes: Absolute (set) fields
+            additive_relatives: Fields that will be added onto if an existing row is present.
         """
         if self.database_engine.can_native_upsert:
             absolute_updates = [
@@ -486,20 +502,17 @@ class StatsStore(StateDeltasStore):
                 current_row.update(absolutes)
                 self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
 
-    async def _calculate_and_set_initial_state_for_room(
-        self, room_id: str
-    ) -> Tuple[dict, dict, int]:
+    async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
         """Calculate and insert an entry into room_stats_current.
 
         Args:
             room_id: The room ID under calculation.
-
-        Returns:
-            A tuple of room state, membership counts and stream position.
         """
 
-        def _fetch_current_state_stats(txn):
-            pos = self.get_room_max_stream_ordering()
+        def _fetch_current_state_stats(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
+            pos = self.get_room_max_stream_ordering()  # type: ignore[attr-defined]
 
             rows = self.db_pool.simple_select_many_txn(
                 txn,
@@ -519,7 +532,7 @@ class StatsStore(StateDeltasStore):
                 retcols=["event_id"],
             )
 
-            event_ids = [row["event_id"] for row in rows]
+            event_ids = cast(List[str], [row["event_id"] for row in rows])
 
             txn.execute(
                 """
@@ -533,15 +546,15 @@ class StatsStore(StateDeltasStore):
 
             txn.execute(
                 """
-                    SELECT COALESCE(count(*), 0) FROM current_state_events
+                    SELECT COUNT(*) FROM current_state_events
                     WHERE room_id = ?
                 """,
                 (room_id,),
             )
 
-            (current_state_events_count,) = txn.fetchone()
+            current_state_events_count = cast(Tuple[int], txn.fetchone())[0]
 
-            users_in_room = self.get_users_in_room_txn(txn, room_id)
+            users_in_room = self.get_users_in_room_txn(txn, room_id)  # type: ignore[attr-defined]
 
             return (
                 event_ids,
@@ -561,7 +574,7 @@ class StatsStore(StateDeltasStore):
             "get_initial_state_for_room", _fetch_current_state_stats
         )
 
-        state_event_map = await self.get_events(event_ids, get_prev_content=False)
+        state_event_map = await self.get_events(event_ids, get_prev_content=False)  # type: ignore[attr-defined]
 
         room_state = {
             "join_rules": None,
@@ -617,8 +630,10 @@ class StatsStore(StateDeltasStore):
             },
         )
 
-    async def _calculate_and_set_initial_state_for_user(self, user_id):
-        def _calculate_and_set_initial_state_for_user_txn(txn):
+    async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None:
+        def _calculate_and_set_initial_state_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, int]:
             pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
 
             txn.execute(
@@ -629,7 +644,7 @@ class StatsStore(StateDeltasStore):
                 """,
                 (user_id,),
             )
-            (count,) = txn.fetchone()
+            count = cast(Tuple[int], txn.fetchone())[0]
             return count, pos
 
         joined_rooms, pos = await self.db_pool.runInteraction(
@@ -673,7 +688,9 @@ class StatsStore(StateDeltasStore):
             users that exist given this query
         """
 
-        def get_users_media_usage_paginate_txn(txn):
+        def get_users_media_usage_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
             filters = []
             args = [self.hs.config.server.server_name]
 
@@ -728,7 +745,7 @@ class StatsStore(StateDeltasStore):
                 sql_base=sql_base,
             )
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = """
                 SELECT
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..319464b1fa 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -34,11 +34,11 @@ what sort order was used:
     - topological tokems: "t%d-%d", where the integers map to the topological
       and stream ordering columns respectively.
 """
-import abc
+
 import logging
-from collections import namedtuple
 from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
 
+import attr
 from frozendict import frozendict
 
 from twisted.internet import defer
@@ -49,6 +49,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_in_list_sql_clause,
 )
@@ -73,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological"
 
 
 # Used as return values for pagination APIs
-_EventDictReturn = namedtuple(
-    "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
-)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventDictReturn:
+    event_id: str
+    topological_ordering: Optional[int]
+    stream_ordering: int
 
 
 def generate_pagination_where_clause(
@@ -333,13 +336,13 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     return " AND ".join(clauses), args
 
 
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
-    """This is an abstract base class where subclasses must implement
-    `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
-    which can be called in the initializer.
-    """
-
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
@@ -371,13 +374,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         self._stream_order_on_start = self.get_room_max_stream_ordering()
 
-    @abc.abstractmethod
     def get_room_max_stream_ordering(self) -> int:
-        raise NotImplementedError()
+        """Get the stream_ordering of regular events that we have committed up to
+
+        Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+        """
+        return self._stream_id_gen.get_current_token()
 
-    @abc.abstractmethod
     def get_room_min_stream_ordering(self) -> int:
-        raise NotImplementedError()
+        """Get the stream_ordering of backfilled events that we have committed up to
+
+        Backfilled events use *negative* stream orderings, so this returns the
+        minimum negative stream id such that all stream ids greater than or
+        equal to it have been successfully persisted.
+        """
+        return self._backfill_id_gen.get_current_token()
 
     def get_room_max_token(self) -> RoomStreamToken:
         """Get a `RoomStreamToken` that marks the current maximum persisted
@@ -819,7 +831,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         for event, row in zip(events, rows):
             stream = row.stream_ordering
             if topo_order and row.topological_ordering:
-                topo = row.topological_ordering
+                topo: Optional[int] = row.topological_ordering
             else:
                 topo = None
             internal = event.internal_metadata
@@ -1343,11 +1355,3 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             retcol="instance_name",
             desc="get_name_from_instance_id",
         )
-
-
-class StreamStore(StreamWorkerStore):
-    def get_room_max_stream_ordering(self) -> int:
-        return self._stream_id_gen.get_current_token()
-
-    def get_room_min_stream_ordering(self) -> int:
-        return self._backfill_id_gen.get_current_token()
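With the abstract base class and the StreamStore shim removed, StreamWorkerStore now implements both token getters directly from its id generators. The docstrings added above describe the convention they rely on; a trivial illustration with invented values:

    # Live events are persisted with increasing positive stream orderings,
    # while backfilled history gets decreasing negative ones, which is why the
    # maximum and minimum tokens come from two separate id generators.
    live_orderings = [1, 2, 3, 4]       # forward persistence; max token is 4
    backfill_orderings = [-1, -2, -3]   # backfill; min token is -3
    assert max(live_orderings) == 4
    assert min(backfill_orderings) == -3
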
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 8f510de53d..c8e508a910 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,11 +15,13 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Tuple, cast
 
+from synapse.replication.tcp.streams import TagAccountDataStream
 from synapse.storage._base import db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -204,6 +206,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             The next account data ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -230,6 +233,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             The next account data ID.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None:
             sql = (
@@ -258,6 +262,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             next_id: The the revision to advance to.
         """
         assert self._can_write_to_account_data
+        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         txn.call_after(
             self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -287,6 +292,21 @@ class TagsWorkerStore(AccountDataWorkerStore):
                 # than the id that the client has.
                 pass
 
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
 
 class TagsStore(TagsWorkerStore):
     pass
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..6c299cafa5 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -13,16 +13,19 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
 from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
 
 import attr
 from canonicaljson import encode_canonical_json
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
@@ -35,16 +38,6 @@ db_binary_type = memoryview
 logger = logging.getLogger(__name__)
 
 
-_TransactionRow = namedtuple(
-    "_TransactionRow",
-    ("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
-)
-
-_UpdateTransactionRow = namedtuple(
-    "_TransactionRow", ("response_code", "response_json")
-)
-
-
 class DestinationSortOrder(Enum):
     """Enum to define the sorting method used when returning destinations."""
 
@@ -71,7 +64,12 @@ class DestinationRetryTimings:
 
 
 class TransactionWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -82,7 +80,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         now = self._clock.time_msec()
         month_ago = now - 30 * 24 * 60 * 60 * 1000
 
-        def _cleanup_transactions_txn(txn):
+        def _cleanup_transactions_txn(txn: LoggingTransaction) -> None:
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
 
         await self.db_pool.runInteraction(
@@ -112,7 +110,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             origin,
         )
 
-    def _get_received_txn_response(self, txn, transaction_id, origin):
+    def _get_received_txn_response(
+        self, txn: LoggingTransaction, transaction_id: str, origin: str
+    ) -> Optional[Tuple[int, JsonDict]]:
         result = self.db_pool.simple_select_one_txn(
             txn,
             table="received_transactions",
@@ -187,7 +187,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         return result
 
     def _get_destination_retry_timings(
-        self, txn, destination: str
+        self, txn: LoggingTransaction, destination: str
     ) -> Optional[DestinationRetryTimings]:
         result = self.db_pool.simple_select_one_txn(
             txn,
@@ -222,7 +222,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         """
 
         if self.database_engine.can_native_upsert:
-            return await self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "set_destination_retry_timings",
                 self._set_destination_retry_timings_native,
                 destination,
@@ -232,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
                 db_autocommit=True,  # Safe as its a single upsert
             )
         else:
-            return await self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "set_destination_retry_timings",
                 self._set_destination_retry_timings_emulated,
                 destination,
@@ -242,8 +242,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             )
 
     def _set_destination_retry_timings_native(
-        self, txn, destination, failure_ts, retry_last_ts, retry_interval
-    ):
+        self,
+        txn: LoggingTransaction,
+        destination: str,
+        failure_ts: Optional[int],
+        retry_last_ts: int,
+        retry_interval: int,
+    ) -> None:
         assert self.database_engine.can_native_upsert
 
         # Upsert retry time interval if retry_interval is zero (i.e. we're
@@ -273,8 +278,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         )
 
     def _set_destination_retry_timings_emulated(
-        self, txn, destination, failure_ts, retry_last_ts, retry_interval
-    ):
+        self,
+        txn: LoggingTransaction,
+        destination: str,
+        failure_ts: Optional[int],
+        retry_last_ts: int,
+        retry_interval: int,
+    ) -> None:
         self.database_engine.lock_table(txn, "destinations")
 
         # We need to be careful here as the data may have changed from under us
@@ -384,7 +394,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             last_successful_stream_ordering: the stream_ordering of the most
                 recent successfully-sent PDU
         """
-        return await self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             "destinations",
             keyvalues={"destination": destination},
             values={"last_successful_stream_ordering": last_successful_stream_ordering},
@@ -525,7 +535,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             else:
                 order = "ASC"
 
-            args = []
+            args: List[object] = []
             where_statement = ""
             if destination:
                 args.extend(["%" + destination.lower() + "%"])
@@ -534,7 +544,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             sql_base = f"FROM destinations {where_statement} "
             sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = f"""
                 SELECT destination, retry_last_ts, retry_interval, failure_ts,
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 340ca9e47d..a1a1a6a14a 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
@@ -225,11 +225,14 @@ class UIAuthWorkerStore(SQLBaseStore):
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
     ):
         # Get the current value.
-        result: Dict[str, Any] = self.db_pool.simple_select_one_txn(  # type: ignore
-            txn,
-            table="ui_auth_sessions",
-            keyvalues={"session_id": session_id},
-            retcols=("serverdict",),
+        result = cast(
+            Dict[str, Any],
+            self.db_pool.simple_select_one_txn(
+                txn,
+                table="ui_auth_sessions",
+                keyvalues={"session_id": session_id},
+                retcols=("serverdict",),
+            ),
         )
 
         # Update it and add it back to the database.
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e98a45b6af..0f9b8575d3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -32,11 +32,14 @@ if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.types import Connection
 from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
 
@@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ) -> None:
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 50d08094d5..2a3d47185a 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 66  # remember to update the list below when updating
+SCHEMA_VERSION = 67  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -50,6 +50,9 @@ Changes in SCHEMA_VERSION = 65:
 Changes in SCHEMA_VERSION = 66:
     - Queries on state_key columns are now disambiguated (ie, the codebase can handle
       the `events` table having a `state_key` column).
+
+Changes in SCHEMA_VERSION = 67:
+    - state_events.prev_state is no longer written to.
 """
 
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 4ff3013908..b8112e1c05 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -74,8 +74,6 @@ class IdGenerator:
 def _load_current_id(
     db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
 ) -> int:
-    # debug logging for https://github.com/matrix-org/synapse/issues/7968
-    logger.info("initialising stream generator for %s(%s)", table, column)
     cur = db_conn.cursor(txn_name="_load_current_id")
     if step == 1:
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
@@ -86,7 +84,9 @@ def _load_current_id(
     (val,) = result
     cur.close()
     current_id = int(val) if val else step
-    return (max if step > 0 else min)(current_id, step)
+    res = (max if step > 0 else min)(current_id, step)
+    logger.info("Initialising stream generator for %s(%s): %i", table, column, res)
+    return res
 
 
 class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
diff --git a/synapse/types.py b/synapse/types.py
index fb72f19343..42aeaf6270 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -15,7 +15,6 @@
 import abc
 import re
 import string
-from collections import namedtuple
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -59,9 +58,11 @@ StateKey = Tuple[str, str]
 StateMap = Mapping[StateKey, T]
 MutableStateMap = MutableMapping[StateKey, T]
 
-# the type of a JSON-serialisable dict. This could be made stronger, but it will
-# do for now.
+# JSON types. These could be made stronger, but will do for now.
+# A JSON-serialisable dict.
 JsonDict = Dict[str, Any]
+# A JSON-serialisable object.
+JsonSerializable = object
 
 
 # Note that this seems to require inheriting *directly* from Interface in order
@@ -225,8 +226,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
     localpart = attr.ib(type=str)
     domain = attr.ib(type=str)
 
-    # Because this class is a namedtuple of strings and booleans, it is deeply
-    # immutable.
+    # Because this is a frozen class, it is deeply immutable.
     def __copy__(self):
         return self
 
@@ -706,16 +706,18 @@ class PersistedEventPosition:
         return RoomStreamToken(None, self.stream)
 
 
-class ThirdPartyInstanceID(
-    namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThirdPartyInstanceID:
+    appservice_id: Optional[str]
+    network_id: Optional[str]
+
     # Deny iteration because it will bite you if you try to create a singleton
     # set by:
     #    users = set(user)
     def __iter__(self):
         raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
 
-    # Because this class is a namedtuple of strings, it is deeply immutable.
+    # Because this is a frozen class, it is deeply immutable.
     def __copy__(self):
         return self
 
@@ -723,22 +725,18 @@ class ThirdPartyInstanceID(
         return self
 
     @classmethod
-    def from_string(cls, s):
+    def from_string(cls, s: str) -> "ThirdPartyInstanceID":
         bits = s.split("|", 2)
         if len(bits) != 2:
             raise SynapseError(400, "Invalid ID %r" % (s,))
 
         return cls(appservice_id=bits[0], network_id=bits[1])
 
-    def to_string(self):
+    def to_string(self) -> str:
         return "%s|%s" % (self.appservice_id, self.network_id)
 
     __str__ = to_string
 
-    @classmethod
-    def create(cls, appservice_id, network_id):
-        return cls(appservice_id=appservice_id, network_id=network_id)
-
 
 @attr.s(slots=True)
 class ReadReceipt:
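ThirdPartyInstanceID keeps its from_string/to_string helpers after the attrs conversion; only the redundant create() classmethod is dropped in favour of the generated constructor. A usage sketch with invented values:

    from synapse.types import ThirdPartyInstanceID

    instance = ThirdPartyInstanceID(appservice_id="irc", network_id="freenode")
    assert instance.to_string() == "irc|freenode"

    parsed = ThirdPartyInstanceID.from_string("irc|freenode")
    assert parsed.appservice_id == "irc"
    assert parsed.network_id == "freenode"
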
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 95f23e27b6..f157132210 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -14,9 +14,8 @@
 
 import json
 import logging
-import re
 import typing
-from typing import Any, Callable, Dict, Generator, Optional, Pattern
+from typing import Any, Callable, Dict, Generator, Optional
 
 import attr
 from frozendict import frozendict
@@ -35,9 +34,6 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-_WILDCARD_RUN = re.compile(r"([\?\*]+)")
-
-
 def _reject_invalid_json(val: Any) -> None:
     """Do not allow Infinity, -Infinity, or NaN values in JSON."""
     raise ValueError("Invalid JSON value: '%s'" % val)
@@ -185,56 +181,3 @@ def log_failure(
     if not consumeErrors:
         return failure
     return None
-
-
-def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern:
-    """Converts a glob to a compiled regex object.
-
-    Args:
-        glob: pattern to match
-        word_boundary: If True, the pattern will be allowed to match at word boundaries
-           anywhere in the string. Otherwise, the pattern is anchored at the start and
-           end of the string.
-
-    Returns:
-        compiled regex pattern
-    """
-
-    # Patterns with wildcards must be simplified to avoid performance cliffs
-    # - The glob `?**?**?` is equivalent to the glob `???*`
-    # - The glob `???*` is equivalent to the regex `.{3,}`
-    chunks = []
-    for chunk in _WILDCARD_RUN.split(glob):
-        # No wildcards? re.escape()
-        if not _WILDCARD_RUN.match(chunk):
-            chunks.append(re.escape(chunk))
-            continue
-
-        # Wildcards? Simplify.
-        qmarks = chunk.count("?")
-        if "*" in chunk:
-            chunks.append(".{%d,}" % qmarks)
-        else:
-            chunks.append(".{%d}" % qmarks)
-
-    res = "".join(chunks)
-
-    if word_boundary:
-        res = re_word_boundary(res)
-    else:
-        # \A anchors at start of string, \Z at end of string
-        res = r"\A" + res + r"\Z"
-
-    return re.compile(res, re.IGNORECASE)
-
-
-def re_word_boundary(r: str) -> str:
-    """
-    Adds word boundary characters to the start and end of an
-    expression to require that the match occur as a whole word,
-    but do so respecting the fact that strings starting or ending
-    with non-word characters will change word boundaries.
-    """
-    # we can't use \b as it chokes on unicode. however \W seems to be okay
-    # as shorthand for [^0-9A-Za-z_].
-    return r"(^|\W)%s(\W|$)" % (r,)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 20ce294209..150a04b53e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import collections
 import inspect
 import itertools
@@ -30,9 +31,11 @@ from typing import (
     Iterator,
     Optional,
     Set,
+    Tuple,
     TypeVar,
     Union,
     cast,
+    overload,
 )
 
 import attr
@@ -55,7 +58,26 @@ logger = logging.getLogger(__name__)
 _T = TypeVar("_T")
 
 
-class ObservableDeferred(Generic[_T]):
+class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
+    """Abstract base class defining the consumer interface of ObservableDeferred"""
+
+    __slots__ = ()
+
+    @abc.abstractmethod
+    def observe(self) -> "defer.Deferred[_T]":
+        """Add a new observer for this ObservableDeferred
+
+        This returns a brand new deferred that is resolved when the underlying
+        deferred is resolved. Interacting with the returned deferred does not
+        affect the underlying deferred.
+
+        Note that the returned Deferred doesn't follow the Synapse logcontext rules -
+        you will probably want to `make_deferred_yieldable` it.
+        """
+        ...
+
+
+class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
     deferred.
@@ -234,6 +256,59 @@ def yieldable_gather_results(
     ).addErrback(unwrapFirstError)
 
 
+T1 = TypeVar("T1")
+T2 = TypeVar("T2")
+T3 = TypeVar("T3")
+
+
+@overload
+def gather_results(
+    deferredList: Tuple[()], consumeErrors: bool = ...
+) -> "defer.Deferred[Tuple[()]]":
+    ...
+
+
+@overload
+def gather_results(
+    deferredList: Tuple["defer.Deferred[T1]"],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1]]":
+    ...
+
+
+@overload
+def gather_results(
+    deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2]]":
+    ...
+
+
+@overload
+def gather_results(
+    deferredList: Tuple[
+        "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
+    ],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2, T3]]":
+    ...
+
+
+def gather_results(  # type: ignore[misc]
+    deferredList: Tuple["defer.Deferred[T1]", ...],
+    consumeErrors: bool = False,
+) -> "defer.Deferred[Tuple[T1, ...]]":
+    """Combines a tuple of `Deferred`s into a single `Deferred`.
+
+    Wraps `defer.gatherResults` to provide type annotations that support heterogeneous
+    lists of `Deferred`s.
+    """
+    # The `type: ignore[misc]` above suppresses
+    # "Overloaded function implementation cannot produce return type of signature 1/2/3"
+    deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
+    return deferred.addCallback(tuple)
+
+
 @attr.s(slots=True)
 class _LinearizerEntry:
     # The number of things executing.
@@ -352,7 +427,7 @@ class Linearizer:
 
         logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 
-        new_defer = make_deferred_yieldable(defer.Deferred())
+        new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
 
         def cb(_r: None) -> "defer.Deferred[None]":
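The new gather_results helper wraps defer.gatherResults with typed overloads for small heterogeneous tuples of deferreds. A minimal usage sketch, assuming the helper is imported from synapse.util.async_helpers (the deferred values are invented):

    from twisted.internet import defer

    from synapse.util.async_helpers import gather_results

    d1: "defer.Deferred[int]" = defer.succeed(1)
    d2: "defer.Deferred[str]" = defer.succeed("two")

    # The two-element overload lets mypy infer Deferred[Tuple[int, str]].
    combined = gather_results((d1, d2))
    combined.addCallback(print)  # prints (1, 'two') once both have fired
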
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index 470f4f91a5..e325f44da3 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -76,6 +76,7 @@ class CachedCall(Generic[TV]):
 
         # Fire off the callable now if this is our first time
         if not self._deferred:
+            assert self._callable is not None
             self._deferred = run_in_background(self._callable)
 
             # we will never need the callable again, so make sure it can be GCed
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index eb96f7e665..3f11a2f9dd 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -69,7 +69,6 @@ try:
         sizer.exclude_refs((), None, "")
         return sizer.asizeof(val, limit=100 if recurse else 0)
 
-
 except ImportError:
 
     def _get_size_of(val: Any, *, recurse: bool = True) -> int:
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 88ccf44337..a3eb5f741b 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,19 +12,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    Optional,
+    TypeVar,
+)
 
 import attr
 
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import (
+    active_span,
+    start_active_span,
+    start_active_span_follows_from,
+)
 from synapse.util import Clock
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
 from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    import opentracing
+
 # the type of the key in the cache
 KV = TypeVar("KV")
 
@@ -54,6 +72,20 @@ class ResponseCacheContext(Generic[KV]):
     """
 
 
+@attr.s(auto_attribs=True)
+class ResponseCacheEntry:
+    result: AbstractObservableDeferred
+    """The (possibly incomplete) result of the operation.
+
+    Note that we continue to store an ObservableDeferred even after the operation
+    completes (rather than switching to an immediate value), since that makes it
+    easier to cache Failure results.
+    """
+
+    opentracing_span_context: "Optional[opentracing.SpanContext]"
+    """The opentracing span which generated/is generating the result"""
+
+
 class ResponseCache(Generic[KV]):
     """
     This caches a deferred response. Until the deferred completes it will be
@@ -63,10 +95,7 @@ class ResponseCache(Generic[KV]):
     """
 
     def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
-        # This is poorly-named: it includes both complete and incomplete results.
-        # We keep complete results rather than switching to absolute values because
-        # that makes it easier to cache Failure results.
-        self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
+        self._result_cache: Dict[KV, ResponseCacheEntry] = {}
 
         self.clock = clock
         self.timeout_sec = timeout_ms / 1000.0
@@ -75,56 +104,63 @@ class ResponseCache(Generic[KV]):
         self._metrics = register_cache("response_cache", name, self, resizable=False)
 
     def size(self) -> int:
-        return len(self.pending_result_cache)
+        return len(self._result_cache)
 
     def __len__(self) -> int:
         return self.size()
 
-    def get(self, key: KV) -> Optional[defer.Deferred]:
-        """Look up the given key.
+    def keys(self) -> Iterable[KV]:
+        """Get the keys currently in the result cache
 
-        Returns a new Deferred (which also doesn't follow the synapse
-        logcontext rules). You will probably want to make_deferred_yieldable the result.
+        Returns both incomplete entries, and (if the timeout on this cache is non-zero),
+        complete entries which are still in the cache.
 
-        If there is no entry for the key, returns None.
+        Note that the returned iterator is not safe in the face of concurrent execution:
+        behaviour is undefined if `wrap` is called during iteration.
+        """
+        return self._result_cache.keys()
+
+    def _get(self, key: KV) -> Optional[ResponseCacheEntry]:
+        """Look up the given key.
 
         Args:
-            key: key to get/set in the cache
+            key: key to get in the cache
 
         Returns:
-            None if there is no entry for this key; otherwise a deferred which
-            resolves to the result.
+            The entry for this key, if any; else None.
         """
-        result = self.pending_result_cache.get(key)
-        if result is not None:
+        entry = self._result_cache.get(key)
+        if entry is not None:
             self._metrics.inc_hits()
-            return result.observe()
+            return entry
         else:
             self._metrics.inc_misses()
             return None
 
     def _set(
-        self, context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]"
-    ) -> "defer.Deferred[RV]":
+        self,
+        context: ResponseCacheContext[KV],
+        deferred: "defer.Deferred[RV]",
+        opentracing_span_context: "Optional[opentracing.SpanContext]",
+    ) -> ResponseCacheEntry:
         """Set the entry for the given key to the given deferred.
 
         *deferred* should run its callbacks in the sentinel logcontext (ie,
         you should wrap normal synapse deferreds with
         synapse.logging.context.run_in_background).
 
-        Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
-        You will probably want to make_deferred_yieldable the result.
-
         Args:
             context: Information about the cache miss
             deferred: The deferred which resolves to the result.
+            opentracing_span_context: An opentracing span wrapping the calculation
 
         Returns:
-            A new deferred which resolves to the actual result.
+            The cache entry object.
         """
         result = ObservableDeferred(deferred, consumeErrors=True)
         key = context.cache_key
-        self.pending_result_cache[key] = result
+        entry = ResponseCacheEntry(result, opentracing_span_context)
+        self._result_cache[key] = entry
 
         def on_complete(r: RV) -> RV:
             # if this cache has a non-zero timeout, and the callback has not cleared
@@ -132,18 +168,18 @@ class ResponseCache(Generic[KV]):
             # its removal later.
             if self.timeout_sec and context.should_cache:
                 self.clock.call_later(
-                    self.timeout_sec, self.pending_result_cache.pop, key, None
+                    self.timeout_sec, self._result_cache.pop, key, None
                 )
             else:
                 # otherwise, remove the result immediately.
-                self.pending_result_cache.pop(key, None)
+                self._result_cache.pop(key, None)
             return r
 
-        # make sure we do this *after* adding the entry to pending_result_cache,
+        # make sure we do this *after* adding the entry to _result_cache,
         # in case the result is already complete (in which case flipping the order would
         # leave us with a stuck entry in the cache).
         result.addBoth(on_complete)
-        return result.observe()
+        return entry
 
     async def wrap(
         self,
@@ -189,20 +225,41 @@ class ResponseCache(Generic[KV]):
         Returns:
             The result of the callback (from the cache, or otherwise)
         """
-        result = self.get(key)
-        if not result:
+        entry = self._get(key)
+        if not entry:
             logger.debug(
                 "[%s]: no cached result for [%s], calculating new one", self._name, key
             )
             context = ResponseCacheContext(cache_key=key)
             if cache_context:
                 kwargs["cache_context"] = context
-            d = run_in_background(callback, *args, **kwargs)
-            result = self._set(context, d)
-        elif not isinstance(result, defer.Deferred) or result.called:
+
+            span_context: Optional[opentracing.SpanContext] = None
+
+            async def cb() -> RV:
+                # NB it is important that we do not `await` before setting span_context!
+                nonlocal span_context
+                with start_active_span(f"ResponseCache[{self._name}].calculate"):
+                    span = active_span()
+                    if span:
+                        span_context = span.context
+                    return await callback(*args, **kwargs)
+
+            d = run_in_background(cb)
+            entry = self._set(context, d, span_context)
+            return await make_deferred_yieldable(entry.result.observe())
+
+        result = entry.result.observe()
+        if result.called:
             logger.info("[%s]: using completed cached result for [%s]", self._name, key)
         else:
             logger.info(
                 "[%s]: using incomplete cached result for [%s]", self._name, key
             )
-        return await make_deferred_yieldable(result)
+
+        span_context = entry.opentracing_span_context
+        with start_active_span_follows_from(
+            f"ResponseCache[{self._name}].wait",
+            contexts=(span_context,) if span_context else (),
+        ):
+            return await make_deferred_yieldable(result)
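
As a hedged sketch of how a caller uses the reworked ResponseCache (ExampleHandler and _do_lookup are invented names, not from the diff): concurrent wrap() calls for the same key share a single computation, the span which produced the result is stored on the cache entry, and later waiters open a "ResponseCache[...].wait" span that follows from it.

# Sketch only; class and method names here are illustrative.
from synapse.util.caches.response_cache import ResponseCache


class ExampleHandler:
    def __init__(self, hs):
        # completed results are kept for 30s before being evicted
        self._cache: ResponseCache[str] = ResponseCache(
            hs.get_clock(), "example_handler", timeout_ms=30 * 1000
        )

    async def get_summary(self, room_id: str) -> dict:
        # Deduplicates concurrent requests for the same room_id.
        return await self._cache.wrap(room_id, self._do_lookup, room_id)

    async def _do_lookup(self, room_id: str) -> dict:
        ...  # expensive work, traced under "ResponseCache[example_handler].calculate"
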
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index de2adacd70..46771a401b 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -142,6 +142,7 @@ class BackgroundFileConsumer:
 
     def wait(self) -> "Deferred[None]":
         """Returns a deferred that resolves when finished writing to file"""
+        assert self._finished_deferred is not None
         return make_deferred_yieldable(self._finished_deferred)
 
     def _resume_paused_producer(self) -> None:
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 3aa9ba3c43..a2dfa1ed05 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -31,6 +31,7 @@ from synapse.types import Requester
 
 from tests import unittest
 from tests.test_utils import simple_async_mock
+from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
 
@@ -210,6 +211,69 @@ class AuthTestCase(unittest.HomeserverTestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         self.get_failure(self.auth.get_user_by_req(request), AuthError)
 
+    @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+    def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
+        """
+        Tests that when an application service passes the device_id URL parameter
+        with the ID of a valid device for the user in question,
+        the requester instance tracks that device ID.
+        """
+        masquerading_user_id = b"@doppelganger:matrix.org"
+        masquerading_device_id = b"DOPPELDEVICE"
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+        )
+        app_service.is_interested_in_user = Mock(return_value=True)
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+        self.store.get_user_by_access_token = simple_async_mock(None)
+        # This also needs to just return a truth-y value
+        self.store.get_device = simple_async_mock({"hidden": False})
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
+        request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        requester = self.get_success(self.auth.get_user_by_req(request))
+        self.assertEquals(
+            requester.user.to_string(), masquerading_user_id.decode("utf8")
+        )
+        self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8"))
+
+    @override_config({"experimental_features": {"msc3202_device_masquerading": True}})
+    def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
+        """
+        Tests that when an application service passes the device_id URL parameter
+        with an ID that is not a valid device ID for the user in question,
+        the request fails with the appropriate error code.
+        """
+        masquerading_user_id = b"@doppelganger:matrix.org"
+        masquerading_device_id = b"NOT_A_REAL_DEVICE_ID"
+        app_service = Mock(
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
+        )
+        app_service.is_interested_in_user = Mock(return_value=True)
+        self.store.get_app_service_by_token = Mock(return_value=app_service)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
+        self.store.get_user_by_access_token = simple_async_mock(None)
+        # This also needs to just return a falsey value
+        self.store.get_device = simple_async_mock(None)
+
+        request = Mock(args={})
+        request.getClientIP.return_value = "127.0.0.1"
+        request.args[b"access_token"] = [self.test_token]
+        request.args[b"user_id"] = [masquerading_user_id]
+        request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+
+        failure = self.get_failure(self.auth.get_user_by_req(request), AuthError)
+        self.assertEquals(failure.value.code, 400)
+        self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE)
+
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = simple_async_mock(
             TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
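
For illustration (derived from the two tests above, not a new API): with the msc3202_device_masquerading experimental flag enabled, an application service masquerades as one of its users on a specific device by adding these query parameters to an ordinary client-server request made with its own access token.

# Query parameters taken from the tests above; values are examples.
masquerade_params = {
    "user_id": "@doppelganger:matrix.org",
    "org.matrix.msc3202.device_id": "DOPPELDEVICE",
}
# If the device ID is not a known device of that user, the request is rejected
# with HTTP 400 and errcode M_EXCLUSIVE, as exercised above.
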
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index f386b5e128..ba2a2bfd64 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -16,13 +16,13 @@ from unittest.mock import Mock
 
 from twisted.internet import defer
 
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, Namespace
 
 from tests import unittest
 
 
-def _regex(regex, exclusive=True):
-    return {"regex": re.compile(regex), "exclusive": exclusive}
+def _regex(regex: str, exclusive: bool = True) -> Namespace:
+    return Namespace(exclusive, None, re.compile(regex))
 
 
 class ApplicationServiceTestCase(unittest.TestCase):
@@ -33,11 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
             url="some_url",
             token="some_token",
             hostname="matrix.org",  # only used by get_groups_for_user
-            namespaces={
-                ApplicationService.NS_USERS: [],
-                ApplicationService.NS_ROOMS: [],
-                ApplicationService.NS_ALIASES: [],
-            },
         )
         self.event = Mock(
             type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index b457dad6d2..b2376e2db9 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -266,7 +266,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         )
 
         # expect signing key update edu
-        self.assertEqual(len(self.edus), 1)
+        self.assertEqual(len(self.edus), 2)
+        self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
         self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
 
         # sign the devices
@@ -491,7 +492,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
     ) -> None:
         """Check that the txn has an EDU with a signing key update."""
         edus = txn["edus"]
-        self.assertEqual(len(edus), 1)
+        self.assertEqual(len(edus), 2)
 
     def generate_and_upload_device_signing_key(
         self, user_id: str, device_id: str
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 663960ff53..bfa156eebb 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -108,6 +108,15 @@ class KnockingStrippedStateEventHelperMixin(TestCase):
                         "state_key": "",
                     },
                 ),
+                (
+                    EventTypes.Topic,
+                    {
+                        "content": {
+                            "topic": "A really cool room",
+                        },
+                        "state_key": "",
+                    },
+                ),
             ]
         )
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index f0723892e4..ddcf3ee348 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -161,8 +161,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
     def test_fallback_key(self):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
-        fallback_key = {"alg1:k1": "key1"}
-        fallback_key2 = {"alg1:k2": "key2"}
+        fallback_key = {"alg1:k1": "fallback_key1"}
+        fallback_key2 = {"alg1:k2": "fallback_key2"}
+        fallback_key3 = {"alg1:k2": "fallback_key3"}
         otk = {"alg1:k2": "key2"}
 
         # we shouldn't have any unused fallback keys yet
@@ -175,7 +176,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             self.handler.upload_keys_for_user(
                 local_user,
                 device_id,
-                {"org.matrix.msc2732.fallback_keys": fallback_key},
+                {"fallback_keys": fallback_key},
             )
         )
 
@@ -220,7 +221,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             self.handler.upload_keys_for_user(
                 local_user,
                 device_id,
-                {"org.matrix.msc2732.fallback_keys": fallback_key},
+                {"fallback_keys": fallback_key},
             )
         )
 
@@ -234,7 +235,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             self.handler.upload_keys_for_user(
                 local_user,
                 device_id,
-                {"org.matrix.msc2732.fallback_keys": fallback_key2},
+                {"fallback_keys": fallback_key2},
             )
         )
 
@@ -271,6 +272,25 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
         )
 
+        # using the unstable prefix should also set the fallback key
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"org.matrix.msc2732.fallback_keys": fallback_key3},
+            )
+        )
+
+        res = self.get_success(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
+        )
+
     def test_replace_master_key(self):
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
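
Sketch of the upload bodies exercised in test_fallback_key above (taken from the test data, not a spec excerpt): both the stable "fallback_keys" field and the unstable MSC2732-prefixed field set the device's fallback key, and the most recent upload wins.

# Shapes copied from the test; "alg1:k2" and the values are test fixtures, not real keys.
stable_upload = {"fallback_keys": {"alg1:k2": "fallback_key2"}}
unstable_upload = {"org.matrix.msc2732.fallback_keys": {"alg1:k2": "fallback_key3"}}
# After unstable_upload, claiming a one-time key for "alg1" returns
# fallback_key3, as asserted at the end of the test.
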
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e1557566e4..496b581726 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -373,9 +373,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             destination: str, room_id: str, event_id: str
         ) -> List[EventBase]:
             return [
-                event_from_pdu_json(
-                    ae.get_pdu_json(), room_version=room_version, outlier=True
-                )
+                event_from_pdu_json(ae.get_pdu_json(), room_version=room_version)
                 for ae in auth_events
             ]
 
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 8a8d369fac..5816295d8b 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -23,6 +23,7 @@ from synapse.types import create_requester
 from synapse.util.stringutils import random_string
 
 from tests import unittest
+from tests.test_utils.event_injection import create_event
 
 logger = logging.getLogger(__name__)
 
@@ -51,6 +52,24 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         self.requester = create_requester(self.user_id, access_token_id=self.token_id)
 
+    def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]:
+        # Create a member event we can use as an auth_event
+        memberEvent, memberEventContext = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                type="m.room.member",
+                sender=self.requester.user.to_string(),
+                state_key=self.requester.user.to_string(),
+                content={"membership": "join"},
+            )
+        )
+        self.get_success(
+            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+        )
+
+        return memberEvent, memberEventContext
+
     def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
         """Create a new event with the given transaction ID. All events produced
         by this method will be considered duplicates.
@@ -156,6 +175,90 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(events), 2)
         self.assertEqual(events[0].event_id, events[1].event_id)
 
+    def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self):
+        """When we set allow_no_prev_events=True, we should be able to create an
+        event without any prev_events (only auth_events).
+        """
+        # Create a member event we can use as an auth_event
+        memberEvent, _ = self._create_and_persist_member_event()
+
+        # Try to create the event with empty prev_events but with some auth_events
+        event, _ = self.get_success(
+            self.handler.create_event(
+                self.requester,
+                {
+                    "type": EventTypes.Message,
+                    "room_id": self.room_id,
+                    "sender": self.requester.user.to_string(),
+                    "content": {"msgtype": "m.text", "body": random_string(5)},
+                },
+                # Empty prev_events is the key thing we're testing here
+                prev_event_ids=[],
+                # But with some auth_events
+                auth_event_ids=[memberEvent.event_id],
+                # Allow no prev_events!
+                allow_no_prev_events=True,
+            )
+        )
+        self.assertIsNotNone(event)
+
+    def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
+        self,
+    ):
+        """When we set allow_no_prev_events=False, we shouldn't be able to create an
+        event without any prev_events even if it has auth_events. Expect an
+        exception to be raised.
+        """
+        # Create a member event we can use as an auth_event
+        memberEvent, _ = self._create_and_persist_member_event()
+
+        # Try to create the event with empty prev_events but with some auth_events
+        self.get_failure(
+            self.handler.create_event(
+                self.requester,
+                {
+                    "type": EventTypes.Message,
+                    "room_id": self.room_id,
+                    "sender": self.requester.user.to_string(),
+                    "content": {"msgtype": "m.text", "body": random_string(5)},
+                },
+                # Empty prev_events is the key thing we're testing here
+                prev_event_ids=[],
+                # But with some auth_events
+                auth_event_ids=[memberEvent.event_id],
+                # We expect the test to fail because empty prev_events are not
+                # allowed here!
+                allow_no_prev_events=False,
+            ),
+            AssertionError,
+        )
+
+    def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
+        self,
+    ):
+        """When we set allow_no_prev_events=True, we still should not be able to
+        create an event with neither prev_events nor auth_events. Expect an
+        exception to be raised.
+        """
+        # Try to create the event with empty prev_events and empty auth_events
+        self.get_failure(
+            self.handler.create_event(
+                self.requester,
+                {
+                    "type": EventTypes.Message,
+                    "room_id": self.room_id,
+                    "sender": self.requester.user.to_string(),
+                    "content": {"msgtype": "m.text", "body": random_string(5)},
+                },
+                prev_event_ids=[],
+                # The event should be rejected when there are no auth_events
+                auth_event_ids=[],
+                # Allow no prev_events!
+                allow_no_prev_events=True,
+            ),
+            AssertionError,
+        )
+
 
 class ServerAclValidationTestCase(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index b25a06b427..eca6a443af 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,6 +20,7 @@ from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
 from synapse.handlers.room import RoomEventSource
 from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.storage.databases.main.event_push_actions import NotifCounts
 from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
 from synapse.types import PersistedEventPosition
 
@@ -166,7 +167,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
+            NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
         )
 
         self.persist(
@@ -179,7 +180,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
+            NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
         )
 
         self.persist(
@@ -194,7 +195,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
+            NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
         )
 
     def test_get_rooms_for_user_with_stream_ordering(self):
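
For context, a rough sketch (an assumption, since the definition is not shown in this diff) of the NotifCounts attrs class which the checks above now expect in place of a plain dict:

# Assumed shape of synapse.storage.databases.main.event_push_actions.NotifCounts;
# the field names come from the assertions above.
import attr


@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
    notify_count: int
    unread_count: int
    highlight_count: int
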
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 04a869e295..1b6a4bf4b0 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -62,7 +62,11 @@ class FederationAckTestCase(HomeserverTestCase):
                 "federation",
                 "master",
                 token=10,
-                rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])],
+                rows=[
+                    FederationStream.FederationStreamRow(
+                        type="x", data={"test": [1, 2, 3]}
+                    )
+                ],
             )
         )
 
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index 4d152c0d66..1e3fe9c62c 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -23,6 +23,7 @@ from synapse.api.errors import Codes
 from synapse.rest.client import login
 from synapse.server import HomeServer
 from synapse.storage.background_updates import BackgroundUpdater
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -96,7 +97,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
     def _register_bg_update(self) -> None:
         "Adds a bg update but doesn't start it"
 
-        async def _fake_update(progress, batch_size) -> int:
+        async def _fake_update(progress: JsonDict, batch_size: int) -> int:
             await self.clock.sleep(0.2)
             return batch_size
 
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index 5188499ef2..742f194257 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -16,11 +16,14 @@ from typing import List, Optional
 
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.errors import Codes
 from synapse.rest.client import login
 from synapse.server import HomeServer
 from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -31,7 +34,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs: HomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastore()
         self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -44,7 +47,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             ("/_synapse/admin/v1/federation/destinations/dummy",),
         ]
     )
-    def test_requester_is_no_admin(self, url: str):
+    def test_requester_is_no_admin(self, url: str) -> None:
         """
         If the user is not a server admin, an error 403 is returned.
         """
@@ -62,7 +65,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_invalid_parameter(self):
+    def test_invalid_parameter(self) -> None:
         """
         If parameters are invalid, an error is returned.
         """
@@ -95,7 +98,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid search order
         channel = self.make_request(
@@ -105,7 +108,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid destination
         channel = self.make_request(
@@ -117,7 +120,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-    def test_limit(self):
+    def test_limit(self) -> None:
         """
         Testing list of destinations with limit
         """
@@ -137,7 +140,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["next_token"], "5")
         self._check_fields(channel.json_body["destinations"])
 
-    def test_from(self):
+    def test_from(self) -> None:
         """
         Testing list of destinations with a defined starting point (from)
         """
@@ -157,7 +160,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("next_token", channel.json_body)
         self._check_fields(channel.json_body["destinations"])
 
-    def test_limit_and_from(self):
+    def test_limit_and_from(self) -> None:
         """
         Testing list of destinations with a defined starting point and limit
         """
@@ -177,7 +180,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(channel.json_body["destinations"]), 10)
         self._check_fields(channel.json_body["destinations"])
 
-    def test_next_token(self):
+    def test_next_token(self) -> None:
         """
         Testing that `next_token` appears at the right place
         """
@@ -238,7 +241,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(channel.json_body["destinations"]), 1)
         self.assertNotIn("next_token", channel.json_body)
 
-    def test_list_all_destinations(self):
+    def test_list_all_destinations(self) -> None:
         """
         List all destinations.
         """
@@ -259,7 +262,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         # Check that all fields are available
         self._check_fields(channel.json_body["destinations"])
 
-    def test_order_by(self):
+    def test_order_by(self) -> None:
         """
         Testing order list with parameter `order_by`
         """
@@ -268,7 +271,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             expected_destination_list: List[str],
             order_by: Optional[str],
             dir: Optional[str] = None,
-        ):
+        ) -> None:
             """Request the list of destinations in a certain order.
             Assert that order is what we expect
 
@@ -358,13 +361,13 @@ class FederationTestCase(unittest.HomeserverTestCase):
             [dest[0][0], dest[2][0], dest[1][0]], "last_successful_stream_ordering", "b"
         )
 
-    def test_search_term(self):
+    def test_search_term(self) -> None:
         """Test that searching for a destination works correctly"""
 
         def _search_test(
             expected_destination: Optional[str],
             search_term: str,
-        ):
+        ) -> None:
             """Search for a destination and check that the returned destinationis a match
 
             Args:
@@ -410,7 +413,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         _search_test(None, "foo")
         _search_test(None, "bar")
 
-    def test_get_single_destination(self):
+    def test_get_single_destination(self) -> None:
         """
         Get one specific destinations.
         """
@@ -429,7 +432,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         # convert channel.json_body into a List
         self._check_fields([channel.json_body])
 
-    def _create_destinations(self, number_destinations: int):
+    def _create_destinations(self, number_destinations: int) -> None:
         """Create a number of destinations
 
         Args:
@@ -442,7 +445,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
                 self.store.set_destination_last_successful_stream_ordering(dest, 100)
             )
 
-    def _check_fields(self, content: List[JsonDict]):
+    def _check_fields(self, content: List[JsonDict]) -> None:
         """Checks that the expected destination attributes are present in content
 
         Args:
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 81e578fd26..86aff7575c 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -360,7 +360,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
             channel.code,
             msg=channel.json_body,
         )
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
         self.assertEqual(
             "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
             channel.json_body["error"],
@@ -580,7 +580,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
         return server_and_media_id
 
-    def _access_media(self, server_and_media_id, expect_success=True) -> None:
+    def _access_media(
+        self, server_and_media_id: str, expect_success: bool = True
+    ) -> None:
         """
         Try to access a media and check the result
         """
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
index 350a62dda6..81f3ac7f04 100644
--- a/tests/rest/admin/test_registration_tokens.py
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -14,6 +14,7 @@
 import random
 import string
 from http import HTTPStatus
+from typing import Optional
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -42,21 +43,27 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
 
         self.url = "/_synapse/admin/v1/registration_tokens"
 
-    def _new_token(self, **kwargs) -> str:
+    def _new_token(
+        self,
+        token: Optional[str] = None,
+        uses_allowed: Optional[int] = None,
+        pending: int = 0,
+        completed: int = 0,
+        expiry_time: Optional[int] = None,
+    ) -> str:
         """Helper function to create a token."""
-        token = kwargs.get(
-            "token",
-            "".join(random.choices(string.ascii_letters, k=8)),
-        )
+        if token is None:
+            token = "".join(random.choices(string.ascii_letters, k=8))
+
         self.get_success(
             self.store.db_pool.simple_insert(
                 "registration_tokens",
                 {
                     "token": token,
-                    "uses_allowed": kwargs.get("uses_allowed", None),
-                    "pending": kwargs.get("pending", 0),
-                    "completed": kwargs.get("completed", 0),
-                    "expiry_time": kwargs.get("expiry_time", None),
+                    "uses_allowed": uses_allowed,
+                    "pending": pending,
+                    "completed": completed,
+                    "expiry_time": expiry_time,
                 },
             )
         )
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 22f9aa6234..d2c8781cd4 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -66,7 +66,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         )
         self.url = "/_synapse/admin/v1/rooms/%s" % self.room_id
 
-    def test_requester_is_no_admin(self):
+    def test_requester_is_no_admin(self) -> None:
         """
         If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
         """
@@ -81,7 +81,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_room_does_not_exist(self):
+    def test_room_does_not_exist(self) -> None:
         """
         Check that unknown rooms/server return 200
         """
@@ -96,7 +96,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
 
-    def test_room_is_not_valid(self):
+    def test_room_is_not_valid(self) -> None:
         """
         Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
         """
@@ -115,7 +115,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_new_room_user_does_not_exist(self):
+    def test_new_room_user_does_not_exist(self) -> None:
         """
         Tests that the user ID must be from local server but it does not have to exist.
         """
@@ -133,7 +133,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertIn("failed_to_kick_users", channel.json_body)
         self.assertIn("local_aliases", channel.json_body)
 
-    def test_new_room_user_is_not_local(self):
+    def test_new_room_user_is_not_local(self) -> None:
         """
         Check that only local users can create new room to move members.
         """
@@ -151,7 +151,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_block_is_not_bool(self):
+    def test_block_is_not_bool(self) -> None:
         """
         If parameter `block` is not boolean, return an error
         """
@@ -166,7 +166,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
-    def test_purge_is_not_bool(self):
+    def test_purge_is_not_bool(self) -> None:
         """
         If parameter `purge` is not boolean, return an error
         """
@@ -181,7 +181,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
-    def test_purge_room_and_block(self):
+    def test_purge_room_and_block(self) -> None:
         """Test to purge a room and block it.
         Members will not be moved to a new room and will not receive a message.
         """
@@ -212,7 +212,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=True)
         self._has_no_members(self.room_id)
 
-    def test_purge_room_and_not_block(self):
+    def test_purge_room_and_not_block(self) -> None:
         """Test to purge a room and do not block it.
         Members will not be moved to a new room and will not receive a message.
         """
@@ -243,7 +243,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=False)
         self._has_no_members(self.room_id)
 
-    def test_block_room_and_not_purge(self):
+    def test_block_room_and_not_purge(self) -> None:
         """Test to block a room without purging it.
         Members will not be moved to a new room and will not receive a message.
         The room will not be purged.
@@ -299,7 +299,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self._is_blocked(room_id)
 
-    def test_shutdown_room_consent(self):
+    def test_shutdown_room_consent(self) -> None:
         """Test that we can shutdown rooms with local users who have not
         yet accepted the privacy policy. This used to fail when we tried to
         force part the user from the old room.
@@ -351,7 +351,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         self._is_purged(self.room_id)
         self._has_no_members(self.room_id)
 
-    def test_shutdown_room_block_peek(self):
+    def test_shutdown_room_block_peek(self) -> None:
         """Test that a world_readable room can no longer be peeked into after
         it has been shut down.
         Members will be moved to a new room and will receive a message.
@@ -400,7 +400,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         # Assert we can no longer peek into the room
         self._assert_peek(self.room_id, expect_code=HTTPStatus.FORBIDDEN)
 
-    def _is_blocked(self, room_id, expect=True):
+    def _is_blocked(self, room_id: str, expect: bool = True) -> None:
         """Assert that the room is blocked or not"""
         d = self.store.is_room_blocked(room_id)
         if expect:
@@ -408,17 +408,17 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         else:
             self.assertIsNone(self.get_success(d))
 
-    def _has_no_members(self, room_id):
+    def _has_no_members(self, room_id: str) -> None:
         """Assert there is now no longer anyone in the room"""
         users_in_room = self.get_success(self.store.get_users_in_room(room_id))
         self.assertEqual([], users_in_room)
 
-    def _is_member(self, room_id, user_id):
+    def _is_member(self, room_id: str, user_id: str) -> None:
         """Test that user is member of the room"""
         users_in_room = self.get_success(self.store.get_users_in_room(room_id))
         self.assertIn(user_id, users_in_room)
 
-    def _is_purged(self, room_id):
+    def _is_purged(self, room_id: str) -> None:
         """Test that the following tables have been purged of all rows related to the room."""
         for table in PURGE_TABLES:
             count = self.get_success(
@@ -432,7 +432,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
 
             self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
 
-    def _assert_peek(self, room_id, expect_code):
+    def _assert_peek(self, room_id: str, expect_code: int) -> None:
         """Assert that the admin user can (or cannot) peek into the room."""
 
         url = "rooms/%s/initialSync" % (room_id,)
@@ -492,7 +492,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             ("GET", "/_synapse/admin/v2/rooms/delete_status/%s"),
         ]
     )
-    def test_requester_is_no_admin(self, method: str, url: str):
+    def test_requester_is_no_admin(self, method: str, url: str) -> None:
         """
         If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
         """
@@ -507,7 +507,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_room_does_not_exist(self):
+    def test_room_does_not_exist(self) -> None:
         """
         Check that unknown rooms/server return 200
 
@@ -544,7 +544,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             ("GET", "/_synapse/admin/v2/rooms/%s/delete_status"),
         ]
     )
-    def test_room_is_not_valid(self, method: str, url: str):
+    def test_room_is_not_valid(self, method: str, url: str) -> None:
         """
         Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
         """
@@ -562,7 +562,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_new_room_user_does_not_exist(self):
+    def test_new_room_user_does_not_exist(self) -> None:
         """
         Tests that the user ID must be from local server but it does not have to exist.
         """
@@ -580,7 +580,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
 
         self._test_result(delete_id, self.other_user, expect_new_room=True)
 
-    def test_new_room_user_is_not_local(self):
+    def test_new_room_user_is_not_local(self) -> None:
         """
         Check that only local users can create new room to move members.
         """
@@ -598,7 +598,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_block_is_not_bool(self):
+    def test_block_is_not_bool(self) -> None:
         """
         If parameter `block` is not boolean, return an error
         """
@@ -613,7 +613,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
-    def test_purge_is_not_bool(self):
+    def test_purge_is_not_bool(self) -> None:
         """
         If parameter `purge` is not boolean, return an error
         """
@@ -628,7 +628,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
 
-    def test_delete_expired_status(self):
+    def test_delete_expired_status(self) -> None:
         """Test that the task status is removed after expiration."""
 
         # first task, do not purge, that we can create a second task
@@ -699,7 +699,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-    def test_delete_same_room_twice(self):
+    def test_delete_same_room_twice(self) -> None:
         """Test that the call for delete a room at second time gives an exception."""
 
         body = {"new_room_user_id": self.admin_user}
@@ -743,7 +743,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
             expect_new_room=True,
         )
 
-    def test_purge_room_and_block(self):
+    def test_purge_room_and_block(self) -> None:
         """Test to purge a room and block it.
         Members will not be moved to a new room and will not receive a message.
         """
@@ -774,7 +774,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=True)
         self._has_no_members(self.room_id)
 
-    def test_purge_room_and_not_block(self):
+    def test_purge_room_and_not_block(self) -> None:
         """Test to purge a room and do not block it.
         Members will not be moved to a new room and will not receive a message.
         """
@@ -805,7 +805,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=False)
         self._has_no_members(self.room_id)
 
-    def test_block_room_and_not_purge(self):
+    def test_block_room_and_not_purge(self) -> None:
         """Test to block a room without purging it.
         Members will not be moved to a new room and will not receive a message.
         The room will not be purged.
@@ -838,7 +838,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self._is_blocked(self.room_id, expect=True)
         self._has_no_members(self.room_id)
 
-    def test_shutdown_room_consent(self):
+    def test_shutdown_room_consent(self) -> None:
         """Test that we can shutdown rooms with local users who have not
         yet accepted the privacy policy. This used to fail when we tried to
         force part the user from the old room.
@@ -899,7 +899,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
         self._is_purged(self.room_id)
         self._has_no_members(self.room_id)
 
-    def test_shutdown_room_block_peek(self):
+    def test_shutdown_room_block_peek(self) -> None:
         """Test that a world_readable room can no longer be peeked into after
         it has been shut down.
         Members will be moved to a new room and will receive a message.
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 4fedd5fd08..e0b9fe8e91 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -608,7 +608,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid deactivated
         channel = self.make_request(
@@ -618,7 +618,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # unkown order_by
         channel = self.make_request(
@@ -628,7 +628,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid search order
         channel = self.make_request(
@@ -638,7 +638,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
     def test_limit(self):
         """
@@ -1550,7 +1550,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Create user
         body = {
             "password": "abc123",
-            "threepids": [{"medium": "email", "address": "bob@bob.bob"}],
+            # Note that the given email is not in canonical form.
+            "threepids": [{"medium": "email", "address": "Bob@bob.bob"}],
         }
 
         channel = self.make_request(
@@ -2896,7 +2897,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # invalid search order
         channel = self.make_request(
@@ -2906,7 +2907,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative limit
         channel = self.make_request(
@@ -3882,3 +3883,93 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
         self.assertNotIn("messages_per_second", channel.json_body)
         self.assertNotIn("burst_count", channel.json_body)
+
+
+class AccountDataTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs) -> None:
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.url = f"/_synapse/admin/v1/users/{self.other_user}/accountdata"
+
+    def test_no_auth(self) -> None:
+        """Try to get a user's account data without authentication."""
+        channel = self.make_request("GET", self.url, {})
+
+        self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_no_admin(self) -> None:
+        """If the user is not a server admin, an error is returned."""
+        other_user_token = self.login("user", "pass")
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=other_user_token,
+        )
+
+        self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_user_does_not_exist(self) -> None:
+        """Tests that a lookup for a user that does not exist returns a 404"""
+        url = "/_synapse/admin/v1/users/@unknown_person:test/accountdata"
+
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self) -> None:
+        """Tests that a lookup for a user that is not a local returns a 400"""
+        url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/accountdata"
+
+        channel = self.make_request(
+            "GET",
+            url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
+        self.assertEqual("Can only look up local users", channel.json_body["error"])
+
+    def test_success(self) -> None:
+        """Request account data should succeed for an admin."""
+
+        # add account data
+        self.get_success(
+            self.store.add_account_data_for_user(self.other_user, "m.global", {"a": 1})
+        )
+        self.get_success(
+            self.store.add_account_data_to_room(
+                self.other_user, "test_room", "m.per_room", {"b": 2}
+            )
+        )
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
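+        # The response splits the user's account data into a "global" section
+        # and a per-room "rooms" section.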
+        self.assertEqual(
+            {"a": 1}, channel.json_body["account_data"]["global"]["m.global"]
+        )
+        self.assertEqual(
+            {"b": 2},
+            channel.json_body["account_data"]["rooms"]["test_room"]["m.per_room"],
+        )
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 72bbc87b4a..27cb856b0a 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -85,7 +85,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         channel = self.make_request(
             "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
         )
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
 
         channel = self.make_request(
             "POST",
@@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         """Ensure that fallback auth via a captcha works."""
         # Returns a 401 as per the spec
         channel = self.register(
-            401,
+            HTTPStatus.UNAUTHORIZED,
             {"username": "user", "type": "m.login.password", "password": "bar"},
         )
 
@@ -116,15 +116,17 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         )
 
         # Complete the recaptcha step.
-        self.recaptcha(session, 200)
+        self.recaptcha(session, HTTPStatus.OK)
 
         # also complete the dummy auth
-        self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}})
+        self.register(
+            HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}}
+        )
 
         # Now we should have fulfilled a complete auth flow, including
         # the recaptcha fallback step, so we can then send a
         # request to the register API with the session in the authdict.
-        channel = self.register(200, {"auth": {"session": session}})
+        channel = self.register(HTTPStatus.OK, {"auth": {"session": session}})
 
         # We're given a registered user.
         self.assertEqual(channel.json_body["user_id"], "@user:test")
@@ -137,7 +139,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         # will be used.)
         # Returns a 401 as per the spec
         channel = self.register(
-            401, {"username": "user", "type": "m.login.password", "password": "bar"}
+            HTTPStatus.UNAUTHORIZED,
+            {"username": "user", "type": "m.login.password", "password": "bar"},
         )
 
         # Grab the session
@@ -231,7 +234,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """
         # Attempt to delete this device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         # Grab the session
         session = channel.json_body["session"]
@@ -242,7 +247,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.delete_device(
             self.user_tok,
             self.device_id,
-            200,
+            HTTPStatus.OK,
             {
                 "auth": {
                     "type": "m.login.password",
@@ -260,14 +265,16 @@ class UIAuthTests(unittest.HomeserverTestCase):
         UIA - check that still works.
         """
 
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
         session = channel.json_body["session"]
 
         # Make another request providing the UI auth flow.
         self.delete_device(
             self.user_tok,
             self.device_id,
-            200,
+            HTTPStatus.OK,
             {
                 "auth": {
                     "type": "m.login.password",
@@ -293,7 +300,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_devices(401, {"devices": [self.device_id]})
+        channel = self.delete_devices(
+            HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]}
+        )
 
         # Grab the session
         session = channel.json_body["session"]
@@ -303,7 +312,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Make another request providing the UI auth flow, but try to delete the
         # second device.
         self.delete_devices(
-            200,
+            HTTPStatus.OK,
             {
                 "devices": ["dev2"],
                 "auth": {
@@ -324,7 +333,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         # Grab the session
         session = channel.json_body["session"]
@@ -338,7 +349,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.delete_device(
             self.user_tok,
             "dev2",
-            403,
+            HTTPStatus.FORBIDDEN,
             {
                 "auth": {
                     "type": "m.login.password",
@@ -361,13 +372,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.login("test", self.user_pass, "dev3")
 
         # Attempt to delete a device. This works since the user just logged in.
-        self.delete_device(self.user_tok, "dev2", 200)
+        self.delete_device(self.user_tok, "dev2", HTTPStatus.OK)
 
         # Move the clock forward past the validation timeout.
         self.reactor.advance(6)
 
         # Deleting another device throws the user into UI auth.
-        channel = self.delete_device(self.user_tok, "dev3", 401)
+        channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -378,7 +389,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.delete_device(
             self.user_tok,
             "dev3",
-            200,
+            HTTPStatus.OK,
             {
                 "auth": {
                     "type": "m.login.password",
@@ -393,7 +404,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # due to re-using the previous session.
         #
         # Note that *no auth* information is provided, not even a session ID!
-        self.delete_device(self.user_tok, self.device_id, 200)
+        self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK)
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
@@ -413,7 +424,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(login_resp["user_id"], self.user)
 
         # initiate a UI Auth process by attempting to delete the device
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         # check that SSO is offered
         flows = channel.json_body["flows"]
@@ -426,13 +439,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         )
 
         # that should serve a confirmation page
-        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
 
         # and now the delete request should succeed.
         self.delete_device(
             self.user_tok,
             self.device_id,
-            200,
+            HTTPStatus.OK,
             body={"auth": {"session": session_id}},
         )
 
@@ -445,13 +458,15 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # now call the device deletion API: we should get the option to auth with SSO
         # and not password.
-        channel = self.delete_device(user_tok, device_id, 401)
+        channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED)
 
         flows = channel.json_body["flows"]
         self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
 
     def test_does_not_offer_sso_for_password_user(self):
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         flows = channel.json_body["flows"]
         self.assertEqual(flows, [{"stages": ["m.login.password"]}])
@@ -463,7 +478,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
         self.assertEqual(login_resp["user_id"], self.user)
 
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         flows = channel.json_body["flows"]
         # we have no particular expectations of ordering here
@@ -480,7 +497,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(login_resp["user_id"], self.user)
 
         # start a UI Auth flow by attempting to delete a device
-        channel = self.delete_device(self.user_tok, self.device_id, 401)
+        channel = self.delete_device(
+            self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
+        )
 
         flows = channel.json_body["flows"]
         self.assertIn({"stages": ["m.login.sso"]}, flows)
@@ -496,7 +515,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # ... and the delete op should now fail with a 403
         self.delete_device(
-            self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
+            self.user_tok,
+            self.device_id,
+            HTTPStatus.FORBIDDEN,
+            body={"auth": {"session": session_id}},
         )
 
 
@@ -551,7 +573,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         login_without_refresh = self.make_request(
             "POST", "/_matrix/client/r0/login", body
         )
-        self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result)
+        self.assertEqual(
+            login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result
+        )
         self.assertNotIn("refresh_token", login_without_refresh.json_body)
 
         login_with_refresh = self.make_request(
@@ -559,7 +583,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             {"refresh_token": True, **body},
         )
-        self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
+        self.assertEqual(
+            login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result
+        )
         self.assertIn("refresh_token", login_with_refresh.json_body)
         self.assertIn("expires_in_ms", login_with_refresh.json_body)
 
@@ -577,7 +603,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             },
         )
         self.assertEqual(
-            register_without_refresh.code, 200, register_without_refresh.result
+            register_without_refresh.code,
+            HTTPStatus.OK,
+            register_without_refresh.result,
         )
         self.assertNotIn("refresh_token", register_without_refresh.json_body)
 
@@ -591,7 +619,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
                 "refresh_token": True,
             },
         )
-        self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
+        self.assertEqual(
+            register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result
+        )
         self.assertIn("refresh_token", register_with_refresh.json_body)
         self.assertIn("expires_in_ms", register_with_refresh.json_body)
 
@@ -610,14 +640,14 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response.code, 200, login_response.result)
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
 
         refresh_response = self.make_request(
             "POST",
             "/_matrix/client/v1/refresh",
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
-        self.assertEqual(refresh_response.code, 200, refresh_response.result)
+        self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
         self.assertIn("access_token", refresh_response.json_body)
         self.assertIn("refresh_token", refresh_response.json_body)
         self.assertIn("expires_in_ms", refresh_response.json_body)
@@ -648,7 +678,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response.code, 200, login_response.result)
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
         self.assertApproximates(
             login_response.json_body["expires_in_ms"], 60 * 1000, 100
         )
@@ -658,7 +688,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/v1/refresh",
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
-        self.assertEqual(refresh_response.code, 200, refresh_response.result)
+        self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
         self.assertApproximates(
             refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
         )
@@ -705,7 +735,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             {"refresh_token": True, **body},
         )
-        self.assertEqual(login_response1.code, 200, login_response1.result)
+        self.assertEqual(login_response1.code, HTTPStatus.OK, login_response1.result)
         self.assertApproximates(
             login_response1.json_body["expires_in_ms"], 60 * 1000, 100
         )
@@ -716,7 +746,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response2.code, 200, login_response2.result)
+        self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result)
         nonrefreshable_access_token = login_response2.json_body["access_token"]
 
         # Advance 59 seconds in the future (just shy of 1 minute, the time of expiry)
@@ -818,7 +848,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response.code, 200, login_response.result)
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
         refresh_token = login_response.json_body["refresh_token"]
 
         # Advance shy of 2 minutes into the future
@@ -826,7 +856,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
 
         # Refresh our session. The refresh token should still be valid right now.
         refresh_response = self.use_refresh_token(refresh_token)
-        self.assertEqual(refresh_response.code, 200, refresh_response.result)
+        self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
         self.assertIn(
             "refresh_token",
             refresh_response.json_body,
@@ -846,7 +876,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # This should fail because the refresh token's lifetime has also been
         # diminished as our session expired.
         refresh_response = self.use_refresh_token(refresh_token)
-        self.assertEqual(refresh_response.code, 403, refresh_response.result)
+        self.assertEqual(
+            refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
+        )
 
     def test_refresh_token_invalidation(self):
         """Refresh tokens are invalidated after first use of the next token.
@@ -875,7 +907,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/login",
             body,
         )
-        self.assertEqual(login_response.code, 200, login_response.result)
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
 
         # This first refresh should work properly
         first_refresh_response = self.make_request(
@@ -884,7 +916,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            first_refresh_response.code, 200, first_refresh_response.result
+            first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result
         )
 
         # This one as well, since the token in the first one was never used
@@ -894,7 +926,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            second_refresh_response.code, 200, second_refresh_response.result
+            second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result
         )
 
         # This one should not, since the token from the first refresh is not valid anymore
@@ -904,7 +936,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": first_refresh_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            third_refresh_response.code, 401, third_refresh_response.result
+            third_refresh_response.code,
+            HTTPStatus.UNAUTHORIZED,
+            third_refresh_response.result,
         )
 
         # The associated access token should also be invalid
@@ -913,7 +947,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             "/_matrix/client/r0/account/whoami",
             access_token=first_refresh_response.json_body["access_token"],
         )
-        self.assertEqual(whoami_response.code, 401, whoami_response.result)
+        self.assertEqual(
+            whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result
+        )
 
         # But all other tokens should work (they will expire after some time)
         for access_token in [
@@ -923,7 +959,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             whoami_response = self.make_request(
                 "GET", "/_matrix/client/r0/account/whoami", access_token=access_token
             )
-            self.assertEqual(whoami_response.code, 200, whoami_response.result)
+            self.assertEqual(
+                whoami_response.code, HTTPStatus.OK, whoami_response.result
+            )
 
         # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
         fourth_refresh_response = self.make_request(
@@ -932,7 +970,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": login_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            fourth_refresh_response.code, 403, fourth_refresh_response.result
+            fourth_refresh_response.code,
+            HTTPStatus.FORBIDDEN,
+            fourth_refresh_response.result,
         )
 
         # But refreshing from the last valid refresh token still works
@@ -942,5 +982,5 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": second_refresh_response.json_body["refresh_token"]},
         )
         self.assertEqual(
-            fifth_refresh_response.code, 200, fifth_refresh_response.result
+            fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
         )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 397c12c2a6..c026d526ef 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -16,6 +16,7 @@
 import itertools
 import urllib.parse
 from typing import Dict, List, Optional, Tuple
+from unittest.mock import patch
 
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
@@ -23,6 +24,8 @@ from synapse.rest.client import login, register, relations, room, sync
 
 from tests import unittest
 from tests.server import FakeChannel
+from tests.test_utils import make_awaitable
+from tests.test_utils.event_injection import inject_event
 
 
 class RelationsTestCase(unittest.HomeserverTestCase):
@@ -574,11 +577,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations"))
 
         # Request sync.
-        channel = self.make_request("GET", "/sync", access_token=self.user_token)
-        self.assertEquals(200, channel.code, channel.json_body)
-        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
-        self.assertTrue(room_timeline["limited"])
-        _find_and_assert_event(room_timeline["events"])
+        # channel = self.make_request("GET", "/sync", access_token=self.user_token)
+        # self.assertEquals(200, channel.code, channel.json_body)
+        # room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        # self.assertTrue(room_timeline["limited"])
+        # _find_and_assert_event(room_timeline["events"])
 
         # Note that /relations is tested separately in test_aggregation_get_event_for_thread
         # since it needs different data configured.
@@ -651,6 +654,118 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             },
         )
 
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_ignore_invalid_room(self):
+        """Test that we ignore invalid relations over federation."""
+        # Create another room and send a message in it.
+        room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+        res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+        parent_id = res["event_id"]
+
+        # Disable the validation to pretend this came over federation.
+        with patch(
+            "synapse.handlers.message.EventCreationHandler._validate_event_relation",
+            new=lambda self, event: make_awaitable(None),
+        ):
+            # Generate various relations whose parent event is in a different room.
+            self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=self.room,
+                    type="m.reaction",
+                    sender=self.user_id,
+                    content={
+                        "m.relates_to": {
+                            "rel_type": RelationTypes.ANNOTATION,
+                            "event_id": parent_id,
+                            "key": "A",
+                        }
+                    },
+                )
+            )
+
+            self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=self.room,
+                    type="m.room.message",
+                    sender=self.user_id,
+                    content={
+                        "body": "foo",
+                        "msgtype": "m.text",
+                        "m.relates_to": {
+                            "rel_type": RelationTypes.REFERENCE,
+                            "event_id": parent_id,
+                        },
+                    },
+                )
+            )
+
+            self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=self.room,
+                    type="m.room.message",
+                    sender=self.user_id,
+                    content={
+                        "body": "foo",
+                        "msgtype": "m.text",
+                        "m.relates_to": {
+                            "rel_type": RelationTypes.THREAD,
+                            "event_id": parent_id,
+                        },
+                    },
+                )
+            )
+
+            self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=self.room,
+                    type="m.room.message",
+                    sender=self.user_id,
+                    content={
+                        "body": "foo",
+                        "msgtype": "m.text",
+                        "new_content": {
+                            "body": "new content",
+                            "msgtype": "m.text",
+                        },
+                        "m.relates_to": {
+                            "rel_type": RelationTypes.REPLACE,
+                            "event_id": parent_id,
+                        },
+                    },
+                )
+            )
+
+        # They should be ignored when fetching relations.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(channel.json_body["chunk"], [])
+
+        # And when fetching aggregations.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEqual(channel.json_body["chunk"], [])
+
+        # And for bundled aggregations.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{room2}/event/{parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
     def test_edit(self):
         """Test that a simple edit works."""
 
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
new file mode 100644
index 0000000000..721454c187
--- /dev/null
+++ b/tests/rest/client/test_room_batch.py
@@ -0,0 +1,180 @@
+import logging
+from typing import List, Tuple
+from unittest.mock import Mock, patch
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.appservice import ApplicationService
+from synapse.rest import admin
+from synapse.rest.client import login, register, room, room_batch
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests import unittest
+
+logger = logging.getLogger(__name__)
+
+
+def _create_join_state_events_for_batch_send_request(
+    virtual_user_ids: List[str],
+    insert_time: int,
+) -> List[JsonDict]:
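+    """Builds a join membership event for each virtual user, suitable for use
+    as the `state_events_at_start` of a batch-send request."""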
+    return [
+        {
+            "type": EventTypes.Member,
+            "sender": virtual_user_id,
+            "origin_server_ts": insert_time,
+            "content": {
+                "membership": "join",
+                "displayname": "display-name-for-%s" % (virtual_user_id,),
+            },
+            "state_key": virtual_user_id,
+        }
+        for virtual_user_id in virtual_user_ids
+    ]
+
+
+def _create_message_events_for_batch_send_request(
+    virtual_user_id: str, insert_time: int, count: int
+) -> List[JsonDict]:
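+    """Builds `count` historical text messages from the virtual user, each
+    marked with the MSC2716 "historical" content field."""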
+    return [
+        {
+            "type": EventTypes.Message,
+            "sender": virtual_user_id,
+            "origin_server_ts": insert_time,
+            "content": {
+                "msgtype": "m.text",
+                "body": "Historical %d" % (i),
+                EventContentFields.MSC2716_HISTORICAL: True,
+            },
+        }
+        for i in range(count)
+    ]
+
+
+class RoomBatchTestCase(unittest.HomeserverTestCase):
+    """Test importing batches of historical messages."""
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room_batch.register_servlets,
+        room.register_servlets,
+        register.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+
+        self.appservice = ApplicationService(
+            token="i_am_an_app_service",
+            hostname="test",
+            id="1234",
+            namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+            # Note: this user does not have to match the regex above
+            sender="@as_main:test",
+        )
+
+        mock_load_appservices = Mock(return_value=[self.appservice])
+        with patch(
+            "synapse.storage.databases.main.appservice.load_appservices",
+            mock_load_appservices,
+        ):
+            hs = self.setup_test_homeserver(config=config)
+        return hs
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.clock = clock
+        self.storage = hs.get_storage()
+
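+        # Register a user in the appservice's exclusive namespace; the batch of
+        # historical events below is sent on behalf of this virtual user.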
+        self.virtual_user_id = self.register_appservice_user(
+            "as_user_potato", self.appservice.token
+        )
+
+    def _create_test_room(self) -> Tuple[str, str, str, str]:
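+        """Creates a room as the application service sender and sends three
+        messages (A, B, C), returning the room ID and the three event IDs."""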
+        room_id = self.helper.create_room_as(
+            self.appservice.sender, tok=self.appservice.token
+        )
+
+        res_a = self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "A",
+            },
+            tok=self.appservice.token,
+        )
+        event_id_a = res_a["event_id"]
+
+        res_b = self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "B",
+            },
+            tok=self.appservice.token,
+        )
+        event_id_b = res_b["event_id"]
+
+        res_c = self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "C",
+            },
+            tok=self.appservice.token,
+        )
+        event_id_c = res_c["event_id"]
+
+        return room_id, event_id_a, event_id_b, event_id_c
+
+    @unittest.override_config({"experimental_features": {"msc2716_enabled": True}})
+    def test_same_state_groups_for_whole_historical_batch(self):
+        """Make sure that when using the `/batch_send` endpoint to import a
+        bunch of historical messages, it re-uses the same `state_group` across
+        the whole batch. This is an easy optimization that is worth getting
+        right, because the state for the whole batch is contained in
+        `state_events_at_start` and can be shared across everything.
+        """
+
+        time_before_room = int(self.clock.time_msec())
+        room_id, event_id_a, _, _ = self._create_test_room()
+
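+        # Import three historical messages via the MSC2716 batch-send endpoint,
+        # using the first event (A) as the `prev_event_id` anchor.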
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s"
+            % (room_id, event_id_a),
+            content={
+                "events": _create_message_events_for_batch_send_request(
+                    self.virtual_user_id, time_before_room, 3
+                ),
+                "state_events_at_start": _create_join_state_events_for_batch_send_request(
+                    [self.virtual_user_id], time_before_room
+                ),
+            },
+            access_token=self.appservice.token,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # Get the historical event IDs that we just imported
+        historical_event_ids = channel.json_body["event_ids"]
+        self.assertEqual(len(historical_event_ids), 3)
+
+        # Fetch the state_groups
+        state_group_map = self.get_success(
+            self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
+        )
+
+        # We expect all of the historical events to be using the same state_group
+        # so there should only be a single state_group here!
+        self.assertEqual(
+            len(state_group_map.keys()),
+            1,
+            "Expected a single state_group to be returned by saw state_groups=%s"
+            % (state_group_map.keys(),),
+        )
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 8698135a76..16e904f15b 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -1,4 +1,5 @@
 # Copyright 2018 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tests/server.py b/tests/server.py
index 40cf5b12c3..ca2b7a5b97 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -11,9 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import hashlib
 import json
 import logging
+import time
+import uuid
+import warnings
 from collections import deque
 from io import SEEK_END, BytesIO
 from typing import (
@@ -27,6 +30,7 @@ from typing import (
     Type,
     Union,
 )
+from unittest.mock import Mock
 
 import attr
 from typing_extensions import Deque
@@ -53,11 +57,24 @@ from twisted.web.http_headers import Headers
 from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.http.site import SynapseRequest
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.utils import setup_test_homeserver as _sth
+from tests.utils import (
+    LEAVE_DB,
+    POSTGRES_BASE_DB,
+    POSTGRES_HOST,
+    POSTGRES_PASSWORD,
+    POSTGRES_USER,
+    USE_POSTGRES_FOR_TESTS,
+    MockClock,
+    default_config,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -450,14 +467,11 @@ class ThreadPool:
         return d
 
 
-def setup_test_homeserver(cleanup_func, *args, **kwargs):
+def _make_test_homeserver_synchronous(server: HomeServer) -> None:
     """
-    Set up a synchronous test server, driven by the reactor used by
-    the homeserver.
+    Make the given test homeserver's database interactions synchronous.
     """
-    server = _sth(cleanup_func, *args, **kwargs)
 
-    # Make the thread pool synchronous.
     clock = server.get_clock()
 
     for database in server.get_datastores().databases:
@@ -485,6 +499,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
 
         pool.runWithConnection = runWithConnection
         pool.runInteraction = runInteraction
+        # Replace the thread pool with a threadless 'thread' pool
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
 
@@ -492,8 +507,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     # thread, so we need to disable the dedicated thread behaviour.
     server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
 
-    return server
-
 
 def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
     clock = ThreadedMemoryReactorClock()
@@ -673,3 +686,171 @@ def connect_client(
     client.makeConnection(FakeTransport(server, reactor))
 
     return client, server
+
+
+class TestHomeServer(HomeServer):
+    DATASTORE_CLASS = DataStore
+
+
+def setup_test_homeserver(
+    cleanup_func,
+    name="test",
+    config=None,
+    reactor=None,
+    homeserver_to_use: Type[HomeServer] = TestHomeServer,
+    **kwargs,
+):
+    """
+    Set up a homeserver suitable for running tests against. Keyword arguments
+    are passed to the Homeserver constructor.
+
+    If no datastore is supplied, one is created and given to the homeserver.
+
+    Args:
+        cleanup_func: The function used to register a cleanup routine for
+                       after the test.
+
+    Calling this method directly is deprecated: you should instead derive from
+    HomeserverTestCase.
+    """
+    if reactor is None:
+        from twisted.internet import reactor
+
+    if config is None:
+        config = default_config(name, parse=True)
+
+    config.ldap_enabled = False
+
+    if "clock" not in kwargs:
+        kwargs["clock"] = MockClock()
+
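+    # Choose the database backend: a throwaway PostgreSQL database when
+    # USE_POSTGRES_FOR_TESTS is set, otherwise an in-memory SQLite database.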
+    if USE_POSTGRES_FOR_TESTS:
+        test_db = "synapse_test_%s" % uuid.uuid4().hex
+
+        database_config = {
+            "name": "psycopg2",
+            "args": {
+                "database": test_db,
+                "host": POSTGRES_HOST,
+                "password": POSTGRES_PASSWORD,
+                "user": POSTGRES_USER,
+                "cp_min": 1,
+                "cp_max": 5,
+            },
+        }
+    else:
+        database_config = {
+            "name": "sqlite3",
+            "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
+        }
+
+    if "db_txn_limit" in kwargs:
+        database_config["txn_limit"] = kwargs["db_txn_limit"]
+
+    database = DatabaseConnectionConfig("master", database_config)
+    config.database.databases = [database]
+
+    db_engine = create_engine(database.config)
+
+    # Create the database before we actually try and connect to it, based off
+    # the template database we generate in setupdb()
+    if isinstance(db_engine, PostgresEngine):
+        db_conn = db_engine.module.connect(
+            database=POSTGRES_BASE_DB,
+            user=POSTGRES_USER,
+            host=POSTGRES_HOST,
+            password=POSTGRES_PASSWORD,
+        )
+        db_conn.autocommit = True
+        cur = db_conn.cursor()
+        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+        cur.execute(
+            "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
+        )
+        cur.close()
+        db_conn.close()
+
+    hs = homeserver_to_use(
+        name,
+        config=config,
+        version_string="Synapse/tests",
+        reactor=reactor,
+    )
+
+    # Install @cache_in_self attributes
+    for key, val in kwargs.items():
+        setattr(hs, "_" + key, val)
+
+    # Mock TLS
+    hs.tls_server_context_factory = Mock()
+    hs.tls_client_options_factory = Mock()
+
+    hs.setup()
+    if homeserver_to_use == TestHomeServer:
+        hs.setup_background_tasks()
+
+    if isinstance(db_engine, PostgresEngine):
+        database = hs.get_datastores().databases[0]
+
+        # We need to do cleanup on PostgreSQL
+        def cleanup():
+            import psycopg2
+
+            # Close all the db pools
+            database._db_pool.close()
+
+            dropped = False
+
+            # Drop the test database
+            db_conn = db_engine.module.connect(
+                database=POSTGRES_BASE_DB,
+                user=POSTGRES_USER,
+                host=POSTGRES_HOST,
+                password=POSTGRES_PASSWORD,
+            )
+            db_conn.autocommit = True
+            cur = db_conn.cursor()
+
+            # Try a few times to drop the DB. Some things may hold on to the
+            # database for a few more seconds due to flakiness, preventing
+            # us from dropping it when the test is over. If we can't drop
+            # it, warn and move on.
+            for _ in range(5):
+                try:
+                    cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+                    db_conn.commit()
+                    dropped = True
+                except psycopg2.OperationalError as e:
+                    warnings.warn(
+                        "Couldn't drop old db: " + str(e), category=UserWarning
+                    )
+                    time.sleep(0.5)
+
+            cur.close()
+            db_conn.close()
+
+            if not dropped:
+                warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+        if not LEAVE_DB:
+            # Register the cleanup hook
+            cleanup_func(cleanup)
+
+    # bcrypt is far too slow to use in unit tests
+    # Need to let the HS build an auth handler and then mess with it
+    # because AuthHandler's constructor requires the HS, so we can't make one
+    # beforehand and pass it in to the HS's constructor (chicken / egg)
+    async def hash(p):
+        return hashlib.md5(p.encode("utf8")).hexdigest()
+
+    hs.get_auth_handler().hash = hash
+
+    async def validate_hash(p, h):
+        return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+    hs.get_auth_handler().validate_hash = validate_hash
+
+    # Make the threadpool and database transactions synchronous for testing.
+    _make_test_homeserver_synchronous(hs)
+
+    return hs
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 01af49a16b..d697d2bc1e 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterable, Set
+from typing import Iterable, Optional, Set
 
 from synapse.api.constants import AccountDataTypes
 
@@ -25,7 +25,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         self.user = "@user:test"
 
     def _update_ignore_list(
-        self, *ignored_user_ids: Iterable[str], ignorer_user_id: str = None
+        self, *ignored_user_ids: Iterable[str], ignorer_user_id: Optional[str] = None
     ) -> None:
         """Update the account data to block the given users."""
         if ignorer_user_id is None:
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index d77c001506..6156dfac4e 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -12,15 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Use backported mock for AsyncMock support on Python 3.6.
-from mock import Mock
+from unittest.mock import Mock
 
 from twisted.internet.defer import Deferred, ensureDeferred
 
 from synapse.storage.background_updates import BackgroundUpdater
 
 from tests import unittest
-from tests.test_utils import make_awaitable
+from tests.test_utils import make_awaitable, simple_async_mock
 
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
@@ -116,14 +115,14 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         )
 
         # Mock out the AsyncContextManager
-        self._update_ctx_manager = Mock(spec=["__aenter__", "__aexit__"])
-        self._update_ctx_manager.__aenter__ = Mock(
-            return_value=make_awaitable(None),
-        )
-        self._update_ctx_manager.__aexit__ = Mock(return_value=make_awaitable(None))
+        class MockCM:
+            __aenter__ = simple_async_mock(return_value=None)
+            __aexit__ = simple_async_mock(return_value=None)
+
+        self._update_ctx_manager = MockCM
 
         # Mock out the `update_handler` callback
-        self._on_update = Mock(return_value=self._update_ctx_manager)
+        self._on_update = Mock(return_value=self._update_ctx_manager())
 
         # Define a default batch size value that's not the same as the internal default
         # value (100).
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index ddad44bd6c..3e4f0579c9 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -23,7 +23,8 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.engines import create_engine
 
 from tests import unittest
-from tests.utils import TestHomeServer, default_config
+from tests.server import TestHomeServer
+from tests.utils import default_config
 
 
 class SQLBaseStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
index 9b6b425425..7556171d8a 100644
--- a/tests/storage/test_e2e_room_keys.py
+++ b/tests/storage/test_e2e_room_keys.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.storage.databases.main.e2e_room_keys import RoomKey
+
 from tests import unittest
 
 # sample room_key data for use in the tests
-room_key = {
+room_key: RoomKey = {
     "first_message_index": 1,
     "forwarded_count": 1,
     "is_verified": False,
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index c3fcf7e7b4..ecfda7677e 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -550,7 +550,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.store.db_pool.simple_select_one_onecol(
                 table="federation_inbound_events_staging",
                 keyvalues={"room_id": room_id},
-                retcol="COALESCE(COUNT(*), 0)",
+                retcol="COUNT(*)",
                 desc="test_prune_inbound_federation_queue",
             )
         )
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index bb5939ba4a..738f3ad1dc 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -14,6 +14,8 @@
 
 from unittest.mock import Mock
 
+from synapse.storage.databases.main.event_push_actions import NotifCounts
+
 from tests.unittest import HomeserverTestCase
 
 USER_ID = "@user:example.com"
@@ -57,11 +59,11 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
             )
             self.assertEquals(
                 counts,
-                {
-                    "notify_count": noitf_count,
-                    "unread_count": 0,  # Unread counts are tested in the sync tests.
-                    "highlight_count": highlight_count,
-                },
+                NotifCounts(
+                    notify_count=noitf_count,
+                    unread_count=0,  # Unread counts are tested in the sync tests.
+                    highlight_count=highlight_count,
+                ),
             )
 
         def _inject_actions(stream, action):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index fccab733c0..5cfdfe9b85 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -19,8 +19,8 @@ from synapse.rest.client import login, room
 from synapse.types import UserID, create_requester
 
 from tests import unittest
+from tests.server import TestHomeServer
 from tests.test_utils import event_injection
-from tests.utils import TestHomeServer
 
 
 class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 40b89fb2ef..46e02f483f 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.rest.media.v1.preview_url_resource import (
-    _calc_og,
+from synapse.rest.media.v1.preview_html import (
+    _get_html_media_encodings,
     decode_body,
-    get_html_media_encodings,
+    parse_html_to_open_graph,
     summarize_paragraphs,
 )
 
@@ -160,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -176,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -195,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(
             og,
@@ -217,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -231,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
@@ -246,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
@@ -261,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
@@ -289,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase):
         <head><title>Foo</title></head><body>Some text.</body></html>
         """.strip()
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_invalid_encoding(self):
@@ -303,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_invalid_encoding2(self):
@@ -318,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
 
     def test_windows_1252(self):
@@ -332,14 +332,14 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = _calc_og(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
 
 
 class MediaEncodingTestCase(unittest.TestCase):
     def test_meta_charset(self):
         """A character encoding is found via the meta tag."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <html>
         <head><meta charset="ascii">
@@ -351,7 +351,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
         # A less well-formed version.
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <html>
         <head>< meta charset = ascii>
@@ -364,7 +364,7 @@ class MediaEncodingTestCase(unittest.TestCase):
 
     def test_meta_charset_underscores(self):
         """A character encoding contains underscore."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <html>
         <head><meta charset="Shift_JIS">
@@ -377,7 +377,7 @@ class MediaEncodingTestCase(unittest.TestCase):
 
     def test_xml_encoding(self):
         """A character encoding is found via the meta tag."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <?xml version="1.0" encoding="ascii"?>
         <html>
@@ -389,7 +389,7 @@ class MediaEncodingTestCase(unittest.TestCase):
 
     def test_meta_xml_encoding(self):
         """Meta tags take precedence over XML encoding."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <?xml version="1.0" encoding="ascii"?>
         <html>
@@ -413,17 +413,17 @@ class MediaEncodingTestCase(unittest.TestCase):
             'text/html; charset=ascii";',
         )
         for header in headers:
-            encodings = get_html_media_encodings(b"", header)
+            encodings = _get_html_media_encodings(b"", header)
             self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
     def test_fallback(self):
         """A character encoding cannot be found in the body or header."""
-        encodings = get_html_media_encodings(b"", "text/html")
+        encodings = _get_html_media_encodings(b"", "text/html")
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
 
     def test_duplicates(self):
         """Ensure each encoding is only attempted once."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <?xml version="1.0" encoding="utf8"?>
         <html>
@@ -437,7 +437,7 @@ class MediaEncodingTestCase(unittest.TestCase):
 
     def test_unknown_invalid(self):
         """A character encoding should be ignored if it is unknown or invalid."""
-        encodings = get_html_media_encodings(
+        encodings = _get_html_media_encodings(
             b"""
         <html>
         <head><meta charset="invalid">
diff --git a/tests/unittest.py b/tests/unittest.py
index eea0903f05..1431848367 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,16 +331,13 @@ class HomeserverTestCase(TestCase):
             time.sleep(0.01)
 
     def wait_for_background_updates(self) -> None:
-        """Block until all background database updates have completed.
-
-        Note that callers must ensure there's a store property created on the
-        testcase.
-        """
+        """Block until all background database updates have completed."""
+        store = self.hs.get_datastore()
         while not self.get_success(
-            self.store.db_pool.updates.has_completed_background_updates()
+            store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db_pool.updates.do_next_background_update(False), by=0.1
+                store.db_pool.updates.do_next_background_update(False), by=0.1
             )
 
     def make_homeserver(self, reactor, clock):
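
The rewritten helper above no longer relies on a store attribute on the test
case; it fetches the datastore from the homeserver itself. A minimal sketch of
a caller, with the class and test names purely illustrative:

from tests.unittest import HomeserverTestCase


class BackgroundUpdatesExampleTestCase(HomeserverTestCase):  # illustrative name
    def test_waits_without_a_store_attribute(self) -> None:
        # No self.store needed: wait_for_background_updates() now looks up the
        # datastore via self.hs.get_datastore() internally.
        self.wait_for_background_updates()
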
diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py
index 1e83ef2f33..025b73e32f 100644
--- a/tests/util/caches/test_response_cache.py
+++ b/tests/util/caches/test_response_cache.py
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from unittest.mock import Mock
+
 from parameterized import parameterized
 
 from twisted.internet import defer
@@ -60,10 +63,15 @@ class ResponseCacheTestCase(TestCase):
             self.successResultOf(wrap_d),
             "initial wrap result should be the same",
         )
+
+        # a second call should return the result without a call to the wrapped function
+        unexpected = Mock(spec=())
+        wrap2_d = defer.ensureDeferred(cache.wrap(0, unexpected))
+        unexpected.assert_not_called()
         self.assertEqual(
             expected_result,
-            self.successResultOf(cache.get(0)),
-            "cache should have the result",
+            self.successResultOf(wrap2_d),
+            "cache should still have the result",
         )
 
     def test_cache_miss(self):
@@ -80,7 +88,7 @@ class ResponseCacheTestCase(TestCase):
             self.successResultOf(wrap_d),
             "initial wrap result should be the same",
         )
-        self.assertIsNone(cache.get(0), "cache should not have the result now")
+        self.assertCountEqual([], cache.keys(), "cache should not have the result now")
 
     def test_cache_expire(self):
         cache = self.with_cache("short_cache", ms=1000)
@@ -92,16 +100,20 @@ class ResponseCacheTestCase(TestCase):
         )
 
         self.assertEqual(expected_result, self.successResultOf(wrap_d))
+
+        # a second call should return the result without a call to the wrapped function
+        unexpected = Mock(spec=())
+        wrap2_d = defer.ensureDeferred(cache.wrap(0, unexpected))
+        unexpected.assert_not_called()
         self.assertEqual(
             expected_result,
-            self.successResultOf(cache.get(0)),
+            self.successResultOf(wrap2_d),
             "cache should still have the result",
         )
 
         # cache eviction timer is handled
         self.reactor.pump((2,))
-
-        self.assertIsNone(cache.get(0), "cache should not have the result now")
+        self.assertCountEqual([], cache.keys(), "cache should not have the result now")
 
     def test_cache_wait_hit(self):
         cache = self.with_cache("neutral_cache")
@@ -133,16 +145,21 @@ class ResponseCacheTestCase(TestCase):
         self.reactor.pump((1, 1))
 
         self.assertEqual(expected_result, self.successResultOf(wrap_d))
+
+        # a second call should immediately return the result without a call to the
+        # wrapped function
+        unexpected = Mock(spec=())
+        wrap2_d = defer.ensureDeferred(cache.wrap(0, unexpected))
+        unexpected.assert_not_called()
         self.assertEqual(
             expected_result,
-            self.successResultOf(cache.get(0)),
+            self.successResultOf(wrap2_d),
             "cache should still have the result",
         )
 
         # (1 + 1 + 2) > 3.0, cache eviction timer is handled
         self.reactor.pump((2,))
-
-        self.assertIsNone(cache.get(0), "cache should not have the result now")
+        self.assertCountEqual([], cache.keys(), "cache should not have the result now")
 
     @parameterized.expand([(True,), (False,)])
     def test_cache_context_nocache(self, should_cache: bool):
@@ -183,10 +200,16 @@ class ResponseCacheTestCase(TestCase):
         self.assertEqual(expected_result, self.successResultOf(wrap2_d))
 
         if should_cache:
+            unexpected = Mock(spec=())
+            wrap3_d = defer.ensureDeferred(cache.wrap(0, unexpected))
+            unexpected.assert_not_called()
             self.assertEqual(
                 expected_result,
-                self.successResultOf(cache.get(0)),
+                self.successResultOf(wrap3_d),
                 "cache should still have the result",
             )
+
         else:
-            self.assertIsNone(cache.get(0), "cache should not have the result")
+            self.assertCountEqual(
+                [], cache.keys(), "cache should not have the result now"
+            )
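
With ResponseCache.get dropped from these tests, a cache hit is now asserted by
wrapping a Mock that must never be invoked, and a miss or expiry by checking
that cache.keys() is empty. A condensed sketch of that pattern, assuming it
runs inside a ResponseCacheTestCase test with cache and expected_result set up
as in the hunks above:

from unittest.mock import Mock

from twisted.internet import defer

# Hit: wrap() should answer from the cache without calling the wrapped function.
unexpected = Mock(spec=())
wrap2_d = defer.ensureDeferred(cache.wrap(0, unexpected))
unexpected.assert_not_called()
self.assertEqual(expected_result, self.successResultOf(wrap2_d))

# Miss, or after the eviction timer has fired: the cache should hold no keys.
self.assertCountEqual([], cache.keys(), "cache should not have the result now")
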
diff --git a/tests/util/test_glob_to_regex.py b/tests/util/test_glob_to_regex.py
deleted file mode 100644
index 220accb92b..0000000000
--- a/tests/util/test_glob_to_regex.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from synapse.util import glob_to_regex
-
-from tests.unittest import TestCase
-
-
-class GlobToRegexTestCase(TestCase):
-    def test_literal_match(self):
-        """patterns without wildcards should match"""
-        pat = glob_to_regex("foobaz")
-        self.assertTrue(
-            pat.match("FoobaZ"), "patterns should match and be case-insensitive"
-        )
-        self.assertFalse(
-            pat.match("x foobaz"), "pattern should not match at word boundaries"
-        )
-
-    def test_wildcard_match(self):
-        pat = glob_to_regex("f?o*baz")
-
-        self.assertTrue(
-            pat.match("FoobarbaZ"),
-            "* should match string and pattern should be case-insensitive",
-        )
-        self.assertTrue(pat.match("foobaz"), "* should match 0 characters")
-        self.assertFalse(pat.match("fooxaz"), "the character after * must match")
-        self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters")
-        self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters")
-
-    def test_multi_wildcard(self):
-        """patterns with multiple wildcards in a row should match"""
-        pat = glob_to_regex("**baz")
-        self.assertTrue(pat.match("agsgsbaz"), "** should match any string")
-        self.assertTrue(pat.match("baz"), "** should match the empty string")
-        self.assertEqual(pat.pattern, r"\A.{0,}baz\Z")
-
-        pat = glob_to_regex("*?baz")
-        self.assertTrue(pat.match("agsgsbaz"), "*? should match any string")
-        self.assertTrue(pat.match("abaz"), "*? should match a single char")
-        self.assertFalse(pat.match("baz"), "*? should not match the empty string")
-        self.assertEqual(pat.pattern, r"\A.{1,}baz\Z")
-
-        pat = glob_to_regex("a?*?*?baz")
-        self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars")
-        self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars")
-        self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars")
-        self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z")
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 5d9c4665aa..621b0f9fcd 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -152,46 +152,11 @@ class LoggingContextTestCase(unittest.TestCase):
             # now it should be restored
             self._check_test_key("one")
 
-    @defer.inlineCallbacks
-    def test_make_deferred_yieldable_on_non_deferred(self):
-        """Check that make_deferred_yieldable does the right thing when its
-        argument isn't actually a deferred"""
-
-        with LoggingContext("one"):
-            d1 = make_deferred_yieldable("bum")
-            self._check_test_key("one")
-
-            r = yield d1
-            self.assertEqual(r, "bum")
-            self._check_test_key("one")
-
     def test_nested_logging_context(self):
         with LoggingContext("foo"):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.name, "foo-bar")
 
-    @defer.inlineCallbacks
-    def test_make_deferred_yieldable_with_await(self):
-        # an async function which returns an incomplete coroutine, but doesn't
-        # follow the synapse rules.
-
-        async def blocking_function():
-            d = defer.Deferred()
-            reactor.callLater(0, d.callback, None)
-            await d
-
-        sentinel_context = current_context()
-
-        with LoggingContext("one"):
-            d1 = make_deferred_yieldable(blocking_function())
-            # make sure that the context was reset by make_deferred_yieldable
-            self.assertIs(current_context(), sentinel_context)
-
-            yield d1
-
-            # now it should be restored
-            self._check_test_key("one")
-
 
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on
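
The two deleted tests covered handing make_deferred_yieldable a plain value and
a bare coroutine; the surviving tests in this file stick to Deferreds and
logging contexts. For reference, a rough sketch of the basic pattern (the
wrapper function is illustrative, not taken from this diff):

from synapse.logging.context import LoggingContext, make_deferred_yieldable


async def wait_for(deferred):
    with LoggingContext("request"):
        # While the Deferred is incomplete, make_deferred_yieldable resets the
        # logcontext to the sentinel and restores "request" when it resolves.
        return await make_deferred_yieldable(deferred)
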
diff --git a/tests/utils.py b/tests/utils.py
index 983859120f..6d013e8518 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -14,12 +14,7 @@
 # limitations under the License.
 
 import atexit
-import hashlib
 import os
-import time
-import uuid
-import warnings
-from typing import Type
 from unittest.mock import Mock, patch
 from urllib import parse as urlparse
 
@@ -28,14 +23,11 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
-from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.logging.context import current_context, set_current_context
-from synapse.server import HomeServer
-from synapse.storage import DataStore
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import PostgresEngine, create_engine
+from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 
 # set this to True to run the tests against postgres instead of sqlite.
@@ -182,171 +174,6 @@ def default_config(name, parse=False):
     return config_dict
 
 
-class TestHomeServer(HomeServer):
-    DATASTORE_CLASS = DataStore
-
-
-def setup_test_homeserver(
-    cleanup_func,
-    name="test",
-    config=None,
-    reactor=None,
-    homeserver_to_use: Type[HomeServer] = TestHomeServer,
-    **kwargs,
-):
-    """
-    Setup a homeserver suitable for running tests against.  Keyword arguments
-    are passed to the Homeserver constructor.
-
-    If no datastore is supplied, one is created and given to the homeserver.
-
-    Args:
-        cleanup_func : The function used to register a cleanup routine for
-                       after the test.
-
-    Calling this method directly is deprecated: you should instead derive from
-    HomeserverTestCase.
-    """
-    if reactor is None:
-        from twisted.internet import reactor
-
-    if config is None:
-        config = default_config(name, parse=True)
-
-    config.ldap_enabled = False
-
-    if "clock" not in kwargs:
-        kwargs["clock"] = MockClock()
-
-    if USE_POSTGRES_FOR_TESTS:
-        test_db = "synapse_test_%s" % uuid.uuid4().hex
-
-        database_config = {
-            "name": "psycopg2",
-            "args": {
-                "database": test_db,
-                "host": POSTGRES_HOST,
-                "password": POSTGRES_PASSWORD,
-                "user": POSTGRES_USER,
-                "cp_min": 1,
-                "cp_max": 5,
-            },
-        }
-    else:
-        database_config = {
-            "name": "sqlite3",
-            "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
-        }
-
-    if "db_txn_limit" in kwargs:
-        database_config["txn_limit"] = kwargs["db_txn_limit"]
-
-    database = DatabaseConnectionConfig("master", database_config)
-    config.database.databases = [database]
-
-    db_engine = create_engine(database.config)
-
-    # Create the database before we actually try and connect to it, based off
-    # the template database we generate in setupdb()
-    if isinstance(db_engine, PostgresEngine):
-        db_conn = db_engine.module.connect(
-            database=POSTGRES_BASE_DB,
-            user=POSTGRES_USER,
-            host=POSTGRES_HOST,
-            password=POSTGRES_PASSWORD,
-        )
-        db_conn.autocommit = True
-        cur = db_conn.cursor()
-        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
-        cur.execute(
-            "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
-        )
-        cur.close()
-        db_conn.close()
-
-    hs = homeserver_to_use(
-        name,
-        config=config,
-        version_string="Synapse/tests",
-        reactor=reactor,
-    )
-
-    # Install @cache_in_self attributes
-    for key, val in kwargs.items():
-        setattr(hs, "_" + key, val)
-
-    # Mock TLS
-    hs.tls_server_context_factory = Mock()
-    hs.tls_client_options_factory = Mock()
-
-    hs.setup()
-    if homeserver_to_use == TestHomeServer:
-        hs.setup_background_tasks()
-
-    if isinstance(db_engine, PostgresEngine):
-        database = hs.get_datastores().databases[0]
-
-        # We need to do cleanup on PostgreSQL
-        def cleanup():
-            import psycopg2
-
-            # Close all the db pools
-            database._db_pool.close()
-
-            dropped = False
-
-            # Drop the test database
-            db_conn = db_engine.module.connect(
-                database=POSTGRES_BASE_DB,
-                user=POSTGRES_USER,
-                host=POSTGRES_HOST,
-                password=POSTGRES_PASSWORD,
-            )
-            db_conn.autocommit = True
-            cur = db_conn.cursor()
-
-            # Try a few times to drop the DB. Some things may hold on to the
-            # database for a few more seconds due to flakiness, preventing
-            # us from dropping it when the test is over. If we can't drop
-            # it, warn and move on.
-            for _ in range(5):
-                try:
-                    cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
-                    db_conn.commit()
-                    dropped = True
-                except psycopg2.OperationalError as e:
-                    warnings.warn(
-                        "Couldn't drop old db: " + str(e), category=UserWarning
-                    )
-                    time.sleep(0.5)
-
-            cur.close()
-            db_conn.close()
-
-            if not dropped:
-                warnings.warn("Failed to drop old DB.", category=UserWarning)
-
-        if not LEAVE_DB:
-            # Register the cleanup hook
-            cleanup_func(cleanup)
-
-    # bcrypt is far too slow to be doing in unit tests
-    # Need to let the HS build an auth handler and then mess with it
-    # because AuthHandler's constructor requires the HS, so we can't make one
-    # beforehand and pass it in to the HS's constructor (chicken / egg)
-    async def hash(p):
-        return hashlib.md5(p.encode("utf8")).hexdigest()
-
-    hs.get_auth_handler().hash = hash
-
-    async def validate_hash(p, h):
-        return hashlib.md5(p.encode("utf8")).hexdigest() == h
-
-    hs.get_auth_handler().validate_hash = validate_hash
-
-    return hs
-
-
 def mock_getRawHeaders(headers=None):
     headers = headers if headers is not None else {}
 
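
With the module-level setup_test_homeserver removed from tests/utils.py, the
route its deleted docstring already recommended is the one to follow: derive
from HomeserverTestCase and let the base class build the homeserver. A minimal
sketch, with the class name and prepare body illustrative:

from tests.unittest import HomeserverTestCase


class ExampleTestCase(HomeserverTestCase):  # illustrative
    def make_homeserver(self, reactor, clock):
        # The base-class helper supersedes calling tests.utils.setup_test_homeserver
        # directly (a call the deleted docstring above already marked as deprecated).
        return self.setup_test_homeserver()

    def prepare(self, reactor, clock, hs):
        self.store = hs.get_datastore()
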
diff --git a/tox.ini b/tox.ini
index cfe6a06942..2ffca14b22 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = packaging, py36, py37, py38, py39, check_codestyle, check_isort
+envlist = packaging, py37, py38, py39, py310, check_codestyle, check_isort
 
 # we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208
 minversion = 2.3.2