summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-03-28 13:54:02 +0100
committerBrendan Abolivier <babolivier@matrix.org>2022-03-28 13:54:02 +0100
commit25507bffc67c40e83cbcd4a79fdfee3667855a7c (patch)
tree5620b2a06a5a9894ac875ddcf3b232db45cae48d /tests
parentMerge branch 'develop' of github.com:matrix-org/synapse into babolivier/sign_... (diff)
parentAdd restrictions by default to open registration in Synapse (#12091) (diff)
downloadsynapse-github/babolivier/sign_json_module.tar.xz
Merge branch 'develop' into babolivier/sign_json_module github/babolivier/sign_json_module babolivier/sign_json_module
Diffstat (limited to 'tests')
-rw-r--r--tests/appservice/test_appservice.py45
-rw-r--r--tests/config/test_background_update.py58
-rw-r--r--tests/config/test_registration_config.py22
-rw-r--r--tests/events/test_utils.py5
-rw-r--r--tests/federation/test_federation_sender.py52
-rw-r--r--tests/handlers/test_admin.py18
-rw-r--r--tests/handlers/test_appservice.py56
-rw-r--r--tests/handlers/test_auth.py34
-rw-r--r--tests/handlers/test_cas.py19
-rw-r--r--tests/handlers/test_deactivate_account.py2
-rw-r--r--tests/handlers/test_device.py76
-rw-r--r--tests/handlers/test_directory.py84
-rw-r--r--tests/handlers/test_e2e_keys.py36
-rw-r--r--tests/handlers/test_federation.py36
-rw-r--r--tests/handlers/test_oidc.py94
-rw-r--r--tests/handlers/test_password_providers.py1
-rw-r--r--tests/handlers/test_presence.py13
-rw-r--r--tests/handlers/test_profile.py49
-rw-r--r--tests/handlers/test_room_summary.py3
-rw-r--r--tests/handlers/test_saml.py24
-rw-r--r--tests/handlers/test_typing.py35
-rw-r--r--tests/module_api/test_api.py10
-rw-r--r--tests/push/test_http.py106
-rw-r--r--tests/push/test_push_rule_evaluator.py23
-rw-r--r--tests/replication/_base.py4
-rw-r--r--tests/replication/tcp/streams/test_events.py2
-rw-r--r--tests/replication/tcp/streams/test_typing.py2
-rw-r--r--tests/replication/test_federation_ack.py2
-rw-r--r--tests/rest/admin/test_background_updates.py21
-rw-r--r--tests/rest/admin/test_user.py19
-rw-r--r--tests/rest/client/test_account.py58
-rw-r--r--tests/rest/client/test_mutual_rooms.py (renamed from tests/rest/client/test_shared_rooms.py)30
-rw-r--r--tests/rest/client/test_relations.py1536
-rw-r--r--tests/rest/client/test_retention.py29
-rw-r--r--tests/rest/client/test_rooms.py18
-rw-r--r--tests/rest/client/test_third_party_rules.py121
-rw-r--r--tests/rest/client/test_transactions.py19
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py44
-rw-r--r--tests/rest/media/v1/test_base.py4
-rw-r--r--tests/rest/media/v1/test_filepath.py48
-rw-r--r--tests/rest/media/v1/test_html_preview.py102
-rw-r--r--tests/rest/media/v1/test_media_storage.py110
-rw-r--r--tests/rest/media/v1/test_oembed.py10
-rw-r--r--tests/rest/media/v1/test_url_preview.py83
-rw-r--r--tests/rest/test_health.py8
-rw-r--r--tests/rest/test_well_known.py20
-rw-r--r--tests/server.py20
-rw-r--r--tests/storage/test_account_data.py17
-rw-r--r--tests/storage/test_background_update.py273
-rw-r--r--tests/storage/test_database.py176
-rw-r--r--tests/storage/test_id_generators.py80
-rw-r--r--tests/storage/test_stream.py20
-rw-r--r--tests/storage/test_unsafe_locale.py46
-rw-r--r--tests/test_visibility.py76
-rw-r--r--tests/util/caches/test_descriptors.py231
-rw-r--r--tests/util/test_async_helpers.py152
-rw-r--r--tests/util/test_check_dependencies.py53
-rw-r--r--tests/util/test_rwlock.py395
-rw-r--r--tests/utils.py110
59 files changed, 3330 insertions, 1510 deletions
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 9bd6275e92..edc584d0cf 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -36,7 +36,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
             hostname="matrix.org",  # only used by get_groups_for_user
         )
         self.event = Mock(
-            type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
+            event_id="$abc:xyz",
+            type="m.something",
+            room_id="!foo:bar",
+            sender="@someone:somewhere",
         )
 
         self.store = Mock()
@@ -50,7 +53,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -62,7 +67,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertFalse(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -76,7 +83,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -90,7 +99,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -104,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertFalse(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -121,7 +134,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -174,7 +189,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertFalse(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -191,7 +208,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -207,7 +226,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -225,7 +246,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(event=self.event, store=self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
diff --git a/tests/config/test_background_update.py b/tests/config/test_background_update.py
new file mode 100644
index 0000000000..0c32c1ca29
--- /dev/null
+++ b/tests/config/test_background_update.py
@@ -0,0 +1,58 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import yaml
+
+from synapse.storage.background_updates import BackgroundUpdater
+
+from tests.unittest import HomeserverTestCase, override_config
+
+
+class BackgroundUpdateConfigTestCase(HomeserverTestCase):
+    # Tests that the default values in the config are correctly loaded. Note that the default
+    # values are loaded when the corresponding config options are commented out, which is why there isn't
+    # a config specified here.
+    def test_default_configuration(self):
+        background_updater = BackgroundUpdater(
+            self.hs, self.hs.get_datastores().main.db_pool
+        )
+
+        self.assertEqual(background_updater.minimum_background_batch_size, 1)
+        self.assertEqual(background_updater.default_background_batch_size, 100)
+        self.assertEqual(background_updater.sleep_enabled, True)
+        self.assertEqual(background_updater.sleep_duration_ms, 1000)
+        self.assertEqual(background_updater.update_duration_ms, 100)
+
+    # Tests that non-default values for the config options are properly picked up and passed on.
+    @override_config(
+        yaml.safe_load(
+            """
+            background_updates:
+                background_update_duration_ms: 1000
+                sleep_enabled: false
+                sleep_duration_ms: 600
+                min_batch_size: 5
+                default_batch_size: 50
+            """
+        )
+    )
+    def test_custom_configuration(self):
+        background_updater = BackgroundUpdater(
+            self.hs, self.hs.get_datastores().main.db_pool
+        )
+
+        self.assertEqual(background_updater.minimum_background_batch_size, 5)
+        self.assertEqual(background_updater.default_background_batch_size, 50)
+        self.assertEqual(background_updater.sleep_enabled, False)
+        self.assertEqual(background_updater.sleep_duration_ms, 600)
+        self.assertEqual(background_updater.update_duration_ms, 1000)
diff --git a/tests/config/test_registration_config.py b/tests/config/test_registration_config.py
index 17a84d20d8..2acdb6ac61 100644
--- a/tests/config/test_registration_config.py
+++ b/tests/config/test_registration_config.py
@@ -11,14 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import synapse.app.homeserver
 from synapse.config import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 
-from tests.unittest import TestCase
+from tests.config.utils import ConfigFileTestCase
 from tests.utils import default_config
 
 
-class RegistrationConfigTestCase(TestCase):
+class RegistrationConfigTestCase(ConfigFileTestCase):
     def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
         """
         session_lifetime should logically be larger than, or at least as large as,
@@ -76,3 +78,19 @@ class RegistrationConfigTestCase(TestCase):
         HomeServerConfig().parse_config_dict(
             {"session_lifetime": "31m", "refresh_token_lifetime": "31m", **config_dict}
         )
+
+    def test_refuse_to_start_if_open_registration_and_no_verification(self):
+        self.generate_config()
+        self.add_lines_to_config(
+            [
+                " ",
+                "enable_registration: true",
+                "registrations_require_3pid: []",
+                "enable_registration_captcha: false",
+                "registration_requires_token: false",
+            ]
+        )
+
+        # Test that allowing open registration without verification raises an error
+        with self.assertRaises(ConfigError):
+            synapse.app.homeserver.setup(["-c", self.config_file])
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index 45e3395b33..00ad19e446 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -16,6 +16,7 @@ from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import RoomVersions
 from synapse.events import make_event_from_dict
 from synapse.events.utils import (
+    SerializeEventConfig,
     copy_power_levels_contents,
     prune_event,
     serialize_event,
@@ -392,7 +393,9 @@ class PruneEventTestCase(unittest.TestCase):
 
 class SerializeEventTestCase(unittest.TestCase):
     def serialize(self, ev, fields):
-        return serialize_event(ev, 1479807801915, only_event_fields=fields)
+        return serialize_event(
+            ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
+        )
 
     def test_event_fields_works_with_keys(self):
         self.assertEqual(
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 60e0c31f43..e90592855a 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -201,9 +201,12 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.assertEqual(len(self.edus), 1)
         stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
 
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
+
         # a second call should produce no new device EDUs
         self.hs.get_federation_sender().send_device_messages("host2")
-        self.pump()
         self.assertEqual(self.edus, [])
 
         # a second device
@@ -232,6 +235,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1")
         device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2")
 
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
+
         # expect two more edus
         self.assertEqual(len(self.edus), 2)
         stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
@@ -265,6 +272,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys)
         )
 
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
+
         # expect signing key update edu
         self.assertEqual(len(self.edus), 2)
         self.assertEqual(self.edus.pop(0)["edu_type"], "m.signing_key_update")
@@ -284,6 +295,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         )
         self.assertEqual(ret["failures"], {})
 
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
+
         # expect two edus, in one or two transactions. We don't know what order the
         # devices will be updated.
         self.assertEqual(len(self.edus), 2)
@@ -307,6 +322,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.login("user", "pass", device_id="D2")
         self.login("user", "pass", device_id="D3")
 
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
+
         # expect three edus
         self.assertEqual(len(self.edus), 3)
         stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
@@ -318,6 +337,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
         )
 
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
+
         # expect three edus, in an unknown order
         self.assertEqual(len(self.edus), 3)
         for edu in self.edus:
@@ -350,12 +373,19 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
         )
 
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
+
         self.assertGreaterEqual(mock_send_txn.call_count, 4)
 
         # recover the server
         mock_send_txn.side_effect = self.record_transaction
         self.hs.get_federation_sender().send_device_messages("host2")
-        self.pump()
+
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
 
         # for each device, there should be a single update
         self.assertEqual(len(self.edus), 3)
@@ -390,6 +420,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
         )
 
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
+
         self.assertGreaterEqual(mock_send_txn.call_count, 4)
 
         # run the prune job
@@ -401,7 +435,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         # recover the server
         mock_send_txn.side_effect = self.record_transaction
         self.hs.get_federation_sender().send_device_messages("host2")
-        self.pump()
+
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
 
         # there should be a single update for this user.
         self.assertEqual(len(self.edus), 1)
@@ -435,6 +472,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.login("user", "pass", device_id="D2")
         self.login("user", "pass", device_id="D3")
 
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
+
         # delete them again
         self.get_success(
             self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
@@ -451,7 +492,10 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         # recover the server
         mock_send_txn.side_effect = self.record_transaction
         self.hs.get_federation_sender().send_device_messages("host2")
-        self.pump()
+
+        # We queue up device list updates to be sent over federation, so we
+        # advance to clear the queue.
+        self.reactor.advance(1)
 
         # ... and we should get a single update for this user.
         self.assertEqual(len(self.edus), 1)
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index abf2a0fe0d..c1579dac61 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -15,11 +15,15 @@
 from collections import Counter
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 import synapse.storage
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.api.room_versions import RoomVersions
 from synapse.rest.client import knock, login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -32,7 +36,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         knock.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.admin_handler = hs.get_admin_handler()
 
         self.user1 = self.register_user("user1", "password")
@@ -41,7 +45,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.user2 = self.register_user("user2", "password")
         self.token2 = self.login("user2", "password")
 
-    def test_single_public_joined_room(self):
+    def test_single_public_joined_room(self) -> None:
         """Test that we write *all* events for a public room"""
         room_id = self.helper.create_room_as(
             self.user1, tok=self.token1, is_public=True
@@ -74,7 +78,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
 
-    def test_single_private_joined_room(self):
+    def test_single_private_joined_room(self) -> None:
         """Tests that we correctly write state when we can't see all events in
         a room.
         """
@@ -112,7 +116,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
 
-    def test_single_left_room(self):
+    def test_single_left_room(self) -> None:
         """Tests that we don't see events in the room after we leave."""
         room_id = self.helper.create_room_as(self.user1, tok=self.token1)
         self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -144,7 +148,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
 
-    def test_single_left_rejoined_private_room(self):
+    def test_single_left_rejoined_private_room(self) -> None:
         """Tests that see the correct events in private rooms when we
         repeatedly join and leave.
         """
@@ -185,7 +189,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
 
-    def test_invite(self):
+    def test_invite(self) -> None:
         """Tests that pending invites get handled correctly."""
         room_id = self.helper.create_room_as(self.user1, tok=self.token1)
         self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -204,7 +208,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(args[1].content["membership"], "invite")
         self.assertTrue(args[2])  # Assert there is at least one bit of state
 
-    def test_knock(self):
+    def test_knock(self) -> None:
         """Tests that knock get handled correctly."""
         # create a knockable v7 room
         room_id = self.helper.create_room_as(
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 072e6bbcdd..cead9f90df 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -59,11 +59,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.event_source = hs.get_event_sources()
 
     def test_notify_interested_services(self):
-        interested_service = self._mkservice(is_interested=True)
+        interested_service = self._mkservice(is_interested_in_event=True)
         services = [
-            self._mkservice(is_interested=False),
+            self._mkservice(is_interested_in_event=False),
             interested_service,
-            self._mkservice(is_interested=False),
+            self._mkservice(is_interested_in_event=False),
         ]
 
         self.mock_as_api.query_user.return_value = make_awaitable(True)
@@ -85,7 +85,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def test_query_user_exists_unknown_user(self):
         user_id = "@someone:anywhere"
-        services = [self._mkservice(is_interested=True)]
+        services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
         self.mock_store.get_user_by_id.return_value = make_awaitable(None)
@@ -102,7 +102,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def test_query_user_exists_known_user(self):
         user_id = "@someone:anywhere"
-        services = [self._mkservice(is_interested=True)]
+        services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
         self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
@@ -127,11 +127,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
         room_id = "!alpha:bet"
         servers = ["aperture"]
-        interested_service = self._mkservice_alias(is_interested_in_alias=True)
+        interested_service = self._mkservice_alias(is_room_alias_in_namespace=True)
         services = [
-            self._mkservice_alias(is_interested_in_alias=False),
+            self._mkservice_alias(is_room_alias_in_namespace=False),
             interested_service,
-            self._mkservice_alias(is_interested_in_alias=False),
+            self._mkservice_alias(is_room_alias_in_namespace=False),
         ]
 
         self.mock_as_api.query_alias.return_value = make_awaitable(True)
@@ -275,7 +275,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         to be pushed out to interested appservices, and that the stream ID is
         updated accordingly.
         """
-        interested_service = self._mkservice(is_interested=True)
+        interested_service = self._mkservice(is_interested_in_event=True)
         services = [interested_service]
         self.mock_store.get_app_services.return_value = services
         self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
@@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         Test sending out of order ephemeral events to the appservice handler
         are ignored.
         """
-        interested_service = self._mkservice(is_interested=True)
+        interested_service = self._mkservice(is_interested_in_event=True)
         services = [interested_service]
 
         self.mock_store.get_app_services.return_value = services
@@ -325,17 +325,45 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             interested_service, ephemeral=[]
         )
 
-    def _mkservice(self, is_interested, protocols=None):
+    def _mkservice(
+        self, is_interested_in_event: bool, protocols: Optional[Iterable] = None
+    ) -> Mock:
+        """
+        Create a new mock representing an ApplicationService.
+
+        Args:
+            is_interested_in_event: Whether this application service will be considered
+                interested in all events.
+            protocols: The third-party protocols that this application service claims to
+                support.
+
+        Returns:
+            A mock representing the ApplicationService.
+        """
         service = Mock()
-        service.is_interested.return_value = make_awaitable(is_interested)
+        service.is_interested_in_event.return_value = make_awaitable(
+            is_interested_in_event
+        )
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         service.protocols = protocols
         return service
 
-    def _mkservice_alias(self, is_interested_in_alias):
+    def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock:
+        """
+        Create a new mock representing an ApplicationService that is or is not interested
+        any given room aliase.
+
+        Args:
+            is_room_alias_in_namespace: If true, the application service will be interested
+                in all room aliases that are queried against it. If false, the application
+                service will not be interested in any room aliases.
+
+        Returns:
+            A mock representing the ApplicationService.
+        """
         service = Mock()
-        service.is_interested_in_alias.return_value = is_interested_in_alias
+        service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         return service
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0c6e55e725..67a7829769 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -15,8 +15,12 @@ from unittest.mock import Mock
 
 import pymacaroons
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.errors import AuthError, ResourceLimitError
 from synapse.rest import admin
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -27,7 +31,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         admin.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.auth_handler = hs.get_auth_handler()
         self.macaroon_generator = hs.get_macaroon_generator()
 
@@ -42,23 +46,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.user1 = self.register_user("a_user", "pass")
 
-    def test_macaroon_caveats(self):
+    def test_macaroon_caveats(self) -> None:
         token = self.macaroon_generator.generate_guest_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        def verify_gen(caveat):
+        def verify_gen(caveat: str) -> bool:
             return caveat == "gen = 1"
 
-        def verify_user(caveat):
+        def verify_user(caveat: str) -> bool:
             return caveat == "user_id = a_user"
 
-        def verify_type(caveat):
+        def verify_type(caveat: str) -> bool:
             return caveat == "type = access"
 
-        def verify_nonce(caveat):
+        def verify_nonce(caveat: str) -> bool:
             return caveat.startswith("nonce =")
 
-        def verify_guest(caveat):
+        def verify_guest(caveat: str) -> bool:
             return caveat == "guest = true"
 
         v = pymacaroons.Verifier()
@@ -69,7 +73,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         v.satisfy_general(verify_guest)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
 
-    def test_short_term_login_token_gives_user_id(self):
+    def test_short_term_login_token_gives_user_id(self) -> None:
         token = self.macaroon_generator.generate_short_term_login_token(
             self.user1, "", duration_in_ms=5000
         )
@@ -84,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             AuthError,
         )
 
-    def test_short_term_login_token_gives_auth_provider(self):
+    def test_short_term_login_token_gives_auth_provider(self) -> None:
         token = self.macaroon_generator.generate_short_term_login_token(
             self.user1, auth_provider_id="my_idp"
         )
@@ -92,7 +96,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.user1, res.user_id)
         self.assertEqual("my_idp", res.auth_provider_id)
 
-    def test_short_term_login_token_cannot_replace_user_id(self):
+    def test_short_term_login_token_cannot_replace_user_id(self) -> None:
         token = self.macaroon_generator.generate_short_term_login_token(
             self.user1, "", duration_in_ms=5000
         )
@@ -112,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             AuthError,
         )
 
-    def test_mau_limits_disabled(self):
+    def test_mau_limits_disabled(self) -> None:
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
         self.get_success(
@@ -127,7 +131,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-    def test_mau_limits_exceeded_large(self):
+    def test_mau_limits_exceeded_large(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
@@ -150,7 +154,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
 
-    def test_mau_limits_parity(self):
+    def test_mau_limits_parity(self) -> None:
         # Ensure we're not at the unix epoch.
         self.reactor.advance(1)
         self.auth_blocking._limit_usage_by_mau = True
@@ -189,7 +193,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-    def test_mau_limits_not_exceeded(self):
+    def test_mau_limits_not_exceeded(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
 
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -211,7 +215,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-    def _get_macaroon(self):
+    def _get_macaroon(self) -> pymacaroons.Macaroon:
         token = self.macaroon_generator.generate_short_term_login_token(
             self.user1, "", duration_in_ms=5000
         )
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index a267228846..a54aa29cf1 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -11,9 +11,14 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+from typing import Any, Dict
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.handlers.cas import CasResponse
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -24,7 +29,7 @@ SERVER_URL = "https://issuer/"
 
 
 class CasHandlerTestCase(HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
         cas_config = {
@@ -40,7 +45,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         return config
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver()
 
         self.handler = hs.get_cas_handler()
@@ -51,7 +56,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         return hs
 
-    def test_map_cas_user_to_user(self):
+    def test_map_cas_user_to_user(self) -> None:
         """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
 
         # stub out the auth handler
@@ -75,7 +80,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             auth_provider_session_id=None,
         )
 
-    def test_map_cas_user_to_existing_user(self):
+    def test_map_cas_user_to_existing_user(self) -> None:
         """Existing users can log in with CAS account."""
         store = self.hs.get_datastores().main
         self.get_success(
@@ -119,7 +124,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             auth_provider_session_id=None,
         )
 
-    def test_map_cas_user_to_invalid_localpart(self):
+    def test_map_cas_user_to_invalid_localpart(self) -> None:
         """CAS automaps invalid characters to base-64 encoding."""
 
         # stub out the auth handler
@@ -150,7 +155,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_required_attributes(self):
+    def test_required_attributes(self) -> None:
         """The required attributes must be met from the CAS response."""
 
         # stub out the auth handler
@@ -166,7 +171,7 @@ class CasHandlerTestCase(HomeserverTestCase):
         auth_handler.complete_sso_login.assert_not_called()
 
         # The response doesn't have any department.
-        cas_response = CasResponse("test_user", {"userGroup": "staff"})
+        cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
         request.reset_mock()
         self.get_success(
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index ddda36c5a9..3a10791226 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -39,7 +39,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
         self.user = self.register_user("user", "pass")
         self.token = self.login("user", "pass")
 
-    def _deactivate_my_account(self):
+    def _deactivate_my_account(self) -> None:
         """
         Deactivates the account `self.user` using `self.token` and asserts
         that it returns a 200 success code.
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 683677fd07..01ea7d2a42 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -14,9 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import synapse.api.errors
-import synapse.handlers.device
-import synapse.storage
+from typing import Optional
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.errors import NotFoundError, SynapseError
+from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -25,28 +30,27 @@ user2 = "@theresa:bbb"
 
 
 class DeviceTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.store = hs.get_datastores().main
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # These tests assume that it starts 1000 seconds in.
         self.reactor.advance(1000)
 
-    def test_device_is_created_with_invalid_name(self):
+    def test_device_is_created_with_invalid_name(self) -> None:
         self.get_failure(
             self.handler.check_device_registered(
                 user_id="@boris:foo",
                 device_id="foo",
-                initial_device_display_name="a"
-                * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
+                initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1),
             ),
-            synapse.api.errors.SynapseError,
+            SynapseError,
         )
 
-    def test_device_is_created_if_doesnt_exist(self):
+    def test_device_is_created_if_doesnt_exist(self) -> None:
         res = self.get_success(
             self.handler.check_device_registered(
                 user_id="@boris:foo",
@@ -59,7 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
         self.assertEqual(dev["display_name"], "display name")
 
-    def test_device_is_preserved_if_exists(self):
+    def test_device_is_preserved_if_exists(self) -> None:
         res1 = self.get_success(
             self.handler.check_device_registered(
                 user_id="@boris:foo",
@@ -81,7 +85,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
         self.assertEqual(dev["display_name"], "display name")
 
-    def test_device_id_is_made_up_if_unspecified(self):
+    def test_device_id_is_made_up_if_unspecified(self) -> None:
         device_id = self.get_success(
             self.handler.check_device_registered(
                 user_id="@theresa:foo",
@@ -93,7 +97,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
         self.assertEqual(dev["display_name"], "display")
 
-    def test_get_devices_by_user(self):
+    def test_get_devices_by_user(self) -> None:
         self._record_users()
 
         res = self.get_success(self.handler.get_devices_by_user(user1))
@@ -131,7 +135,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             device_map["abc"],
         )
 
-    def test_get_device(self):
+    def test_get_device(self) -> None:
         self._record_users()
 
         res = self.get_success(self.handler.get_device(user1, "abc"))
@@ -146,21 +150,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             res,
         )
 
-    def test_delete_device(self):
+    def test_delete_device(self) -> None:
         self._record_users()
 
         # delete the device
         self.get_success(self.handler.delete_device(user1, "abc"))
 
         # check the device was deleted
-        self.get_failure(
-            self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
-        )
+        self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
 
         # we'd like to check the access token was invalidated, but that's a
         # bit of a PITA.
 
-    def test_delete_device_and_device_inbox(self):
+    def test_delete_device_and_device_inbox(self) -> None:
         self._record_users()
 
         # add an device_inbox
@@ -191,7 +193,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         )
         self.assertIsNone(res)
 
-    def test_update_device(self):
+    def test_update_device(self) -> None:
         self._record_users()
 
         update = {"display_name": "new display"}
@@ -200,32 +202,29 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         res = self.get_success(self.handler.get_device(user1, "abc"))
         self.assertEqual(res["display_name"], "new display")
 
-    def test_update_device_too_long_display_name(self):
+    def test_update_device_too_long_display_name(self) -> None:
         """Update a device with a display name that is invalid (too long)."""
         self._record_users()
 
         # Request to update a device display name with a new value that is longer than allowed.
-        update = {
-            "display_name": "a"
-            * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
-        }
+        update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
         self.get_failure(
             self.handler.update_device(user1, "abc", update),
-            synapse.api.errors.SynapseError,
+            SynapseError,
         )
 
         # Ensure the display name was not updated.
         res = self.get_success(self.handler.get_device(user1, "abc"))
         self.assertEqual(res["display_name"], "display 2")
 
-    def test_update_unknown_device(self):
+    def test_update_unknown_device(self) -> None:
         update = {"display_name": "new_display"}
         self.get_failure(
             self.handler.update_device("user_id", "unknown_device_id", update),
-            synapse.api.errors.NotFoundError,
+            NotFoundError,
         )
 
-    def _record_users(self):
+    def _record_users(self) -> None:
         # check this works for both devices which have a recorded client_ip,
         # and those which don't.
         self._record_user(user1, "xyz", "display 0")
@@ -238,8 +237,13 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10000)
 
     def _record_user(
-        self, user_id, device_id, display_name, access_token=None, ip=None
-    ):
+        self,
+        user_id: str,
+        device_id: str,
+        display_name: str,
+        access_token: Optional[str] = None,
+        ip: Optional[str] = None,
+    ) -> None:
         device_id = self.get_success(
             self.handler.check_device_registered(
                 user_id=user_id,
@@ -248,7 +252,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        if ip is not None:
+        if access_token is not None and ip is not None:
             self.get_success(
                 self.store.insert_client_ip(
                     user_id, access_token, ip, "user_agent", device_id
@@ -258,7 +262,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.registration = hs.get_registration_handler()
@@ -266,7 +270,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         return hs
 
-    def test_dehydrate_and_rehydrate_device(self):
+    def test_dehydrate_and_rehydrate_device(self) -> None:
         user_id = "@boris:dehydration"
 
         self.get_success(self.store.register_user(user_id, "foobar"))
@@ -303,7 +307,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
                 access_token=access_token,
                 device_id="not the right device ID",
             ),
-            synapse.api.errors.NotFoundError,
+            NotFoundError,
         )
 
         # dehydrating the right devices should succeed and change our device ID
@@ -331,7 +335,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         # make sure that the device ID that we were initially assigned no longer exists
         self.get_failure(
             self.handler.get_device(user_id, device_id),
-            synapse.api.errors.NotFoundError,
+            NotFoundError,
         )
 
         # make sure that there's no device available for dehydrating now
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 6e403a87c5..11ad44223d 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -12,14 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Any, Awaitable, Callable, Dict
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.api.errors
 import synapse.rest.admin
 from synapse.api.constants import EventTypes
 from synapse.rest.client import directory, login, room
-from synapse.types import RoomAlias, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomAlias, create_requester
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
 class DirectoryTestCase(unittest.HomeserverTestCase):
     """Tests the directory service."""
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.mock_federation = Mock()
         self.mock_registry = Mock()
 
-        self.query_handlers = {}
+        self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
 
-        def register_query_handler(query_type, handler):
+        def register_query_handler(
+            query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+        ) -> None:
             self.query_handlers[query_type] = handler
 
         self.mock_registry.register_query_handler = register_query_handler
@@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
         return hs
 
-    def test_get_local_association(self):
+    def test_get_local_association(self) -> None:
         self.get_success(
             self.store.create_room_alias_association(
                 self.my_room, "!8765qwer:test", ["test"]
@@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
-    def test_get_remote_association(self):
+    def test_get_remote_association(self) -> None:
         self.mock_federation.make_query.return_value = make_awaitable(
             {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
         )
@@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
             ignore_backoff=True,
         )
 
-    def test_incoming_fed_query(self):
+    def test_incoming_fed_query(self) -> None:
         self.get_success(
             self.store.create_room_alias_association(
                 self.your_room, "!8765asdf:test", ["test"]
@@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
         directory.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_directory_handler()
 
         # Create user
@@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
         self.test_user_tok = self.login("user", "pass")
         self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
 
-    def test_create_alias_joined_room(self):
+    def test_create_alias_joined_room(self) -> None:
         """A user can create an alias for a room they're in."""
         self.get_success(
             self.handler.create_association(
@@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
             )
         )
 
-    def test_create_alias_other_room(self):
+    def test_create_alias_other_room(self) -> None:
         """A user cannot create an alias for a room they're NOT in."""
         other_room_id = self.helper.create_room_as(
             self.admin_user, tok=self.admin_user_tok
@@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
             synapse.api.errors.SynapseError,
         )
 
-    def test_create_alias_admin(self):
+    def test_create_alias_admin(self) -> None:
         """An admin can create an alias for a room they're NOT in."""
         other_room_id = self.helper.create_room_as(
             self.test_user, tok=self.test_user_tok
@@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
         directory.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
@@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
         self.test_user_tok = self.login("user", "pass")
         self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
 
-    def _create_alias(self, user):
+    def _create_alias(self, user) -> None:
         # Create a new alias to this room.
         self.get_success(
             self.store.create_room_alias_association(
@@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             )
         )
 
-    def test_delete_alias_not_allowed(self):
+    def test_delete_alias_not_allowed(self) -> None:
         """A user that doesn't meet the expected guidelines cannot delete an alias."""
         self._create_alias(self.admin_user)
         self.get_failure(
@@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             synapse.api.errors.AuthError,
         )
 
-    def test_delete_alias_creator(self):
+    def test_delete_alias_creator(self) -> None:
         """An alias creator can delete their own alias."""
         # Create an alias from a different user.
         self._create_alias(self.test_user)
@@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             synapse.api.errors.SynapseError,
         )
 
-    def test_delete_alias_admin(self):
+    def test_delete_alias_admin(self) -> None:
         """A server admin can delete an alias created by another user."""
         # Create an alias from a different user.
         self._create_alias(self.test_user)
@@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             synapse.api.errors.SynapseError,
         )
 
-    def test_delete_alias_sufficient_power(self):
+    def test_delete_alias_sufficient_power(self) -> None:
         """A user with a sufficient power level should be able to delete an alias."""
         self._create_alias(self.admin_user)
 
@@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         directory.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
@@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         )
         return room_alias
 
-    def _set_canonical_alias(self, content):
+    def _set_canonical_alias(self, content) -> None:
         """Configure the canonical alias state on the room."""
         self.helper.send_state(
             self.room_id,
@@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
             )
         )
 
-    def test_remove_alias(self):
+    def test_remove_alias(self) -> None:
         """Removing an alias that is the canonical alias should remove it there too."""
         # Set this new alias as the canonical alias for this room
         self._set_canonical_alias(
@@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("alias", data["content"])
         self.assertNotIn("alt_aliases", data["content"])
 
-    def test_remove_other_alias(self):
+    def test_remove_other_alias(self) -> None:
         """Removing an alias listed as in alt_aliases should remove it there too."""
         # Create a second alias.
         other_test_alias = "#test2:test"
@@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
 
     servlets = [directory.register_servlets, room.register_servlets]
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
 
         # Add custom alias creation rules to the config.
@@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
 
         return config
 
-    def test_denied(self):
+    def test_denied(self) -> None:
         room_id = self.helper.create_room_as(self.user_id)
 
         channel = self.make_request(
@@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         )
         self.assertEqual(403, channel.code, channel.result)
 
-    def test_allowed(self):
+    def test_allowed(self) -> None:
         room_id = self.helper.create_room_as(self.user_id)
 
         channel = self.make_request(
@@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         )
         self.assertEqual(200, channel.code, channel.result)
 
-    def test_denied_during_creation(self):
+    def test_denied_during_creation(self) -> None:
         """A room alias that is not allowed should be rejected during creation."""
         # Invalid room alias.
         self.helper.create_room_as(
@@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
             extra_content={"room_alias_name": "foo"},
         )
 
-    def test_allowed_during_creation(self):
+    def test_allowed_during_creation(self) -> None:
         """A valid room alias should be allowed during creation."""
         room_id = self.helper.create_room_as(
             self.user_id,
@@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
     data = {"room_alias_name": "unofficial_test"}
     allowed_localpart = "allowed"
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
 
         # Add custom room list publication rules to the config.
@@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+    ) -> HomeServer:
         self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
         self.allowed_access_token = self.login(self.allowed_localpart, "pass")
 
@@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
 
         return hs
 
-    def test_denied_without_publication_permission(self):
+    def test_denied_without_publication_permission(self) -> None:
         """
         Try to create a room, register an alias for it, and publish it,
         as a user without permission to publish rooms.
@@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=403,
         )
 
-    def test_allowed_when_creating_private_room(self):
+    def test_allowed_when_creating_private_room(self) -> None:
         """
         Try to create a room, register an alias for it, and NOT publish it,
         as a user without permission to publish rooms.
@@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=200,
         )
 
-    def test_allowed_with_publication_permission(self):
+    def test_allowed_with_publication_permission(self) -> None:
         """
         Try to create a room, register an alias for it, and publish it,
         as a user WITH permission to publish rooms.
@@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=200,
         )
 
-    def test_denied_publication_with_invalid_alias(self):
+    def test_denied_publication_with_invalid_alias(self) -> None:
         """
         Try to create a room, register an alias for it, and publish it,
         as a user WITH permission to publish rooms.
@@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=403,
         )
 
-    def test_can_create_as_private_room_after_rejection(self):
+    def test_can_create_as_private_room_after_rejection(self) -> None:
         """
         After failing to publish a room with an alias as a user without publish permission,
         retry as the same user, but without publishing the room.
@@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
         self.test_denied_without_publication_permission()
         self.test_allowed_when_creating_private_room()
 
-    def test_can_create_with_permission_after_rejection(self):
+    def test_can_create_with_permission_after_rejection(self) -> None:
         """
         After failing to publish a room with an alias as a user without publish permission,
         retry as someone with permission, using the same alias.
@@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
 
     servlets = [directory.register_servlets, room.register_servlets]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+    ) -> HomeServer:
         room_id = self.helper.create_room_as(self.user_id)
 
         channel = self.make_request(
@@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
 
         return hs
 
-    def test_disabling_room_list(self):
+    def test_disabling_room_list(self) -> None:
         self.room_list_handler.enable_room_list_search = True
         self.directory_handler.enable_room_list_search = True
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 9338ab92e9..ac21a28c43 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -20,33 +20,37 @@ from parameterized import parameterized
 from signedjson import key as key, sign as sign
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         return self.setup_test_homeserver(federation_client=mock.Mock())
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_e2e_keys_handler()
         self.store = self.hs.get_datastores().main
 
-    def test_query_local_devices_no_devices(self):
+    def test_query_local_devices_no_devices(self) -> None:
         """If the user has no devices, we expect an empty list."""
         local_user = "@boris:" + self.hs.hostname
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    def test_reupload_one_time_keys(self):
+    def test_reupload_one_time_keys(self) -> None:
         """we should be able to re-upload the same keys"""
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
-        keys = {
+        keys: JsonDict = {
             "alg1:k1": "key1",
             "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
             "alg2:k3": {"key": "key3"},
@@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
         )
 
-    def test_change_one_time_keys(self):
+    def test_change_one_time_keys(self) -> None:
         """attempts to change one-time-keys should be rejected"""
 
         local_user = "@boris:" + self.hs.hostname
@@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             SynapseError,
         )
 
-    def test_claim_one_time_key(self):
+    def test_claim_one_time_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         keys = {"alg1:k1": "key1"}
@@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def test_fallback_key(self):
+    def test_fallback_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         fallback_key = {"alg1:k1": "fallback_key1"}
@@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
 
-    def test_replace_master_key(self):
+    def test_replace_master_key(self) -> None:
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
         keys1 = {
@@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
-    def test_reupload_signatures(self):
+    def test_reupload_signatures(self) -> None:
         """re-uploading a signature should not fail"""
         local_user = "@boris:" + self.hs.hostname
         keys1 = {
@@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
         self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
 
-    def test_self_signing_key_doesnt_show_up_as_device(self):
+    def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
         """signing keys should be hidden when fetching a user's devices"""
         local_user = "@boris:" + self.hs.hostname
         keys1 = {
@@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    def test_upload_signatures(self):
+    def test_upload_signatures(self) -> None:
         """should check signatures that are uploaded"""
         # set up a user with cross-signing keys and a device.  This user will
         # try uploading signatures
@@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
         )
 
-    def test_query_devices_remote_no_sync(self):
+    def test_query_devices_remote_no_sync(self) -> None:
         """Tests that querying keys for a remote user that we don't share a room
         with returns the cross signing keys correctly.
         """
@@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def test_query_devices_remote_sync(self):
+    def test_query_devices_remote_sync(self) -> None:
         """Tests that querying keys for a remote user that we share a room with,
         but haven't yet fetched the keys for, returns the cross signing keys
         correctly.
@@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             (["device_1", "device_2"],),
         ]
     )
-    def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+    def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
         """Test that requests for all of a remote user's devices are cached.
 
         We do this by asserting that only one call over federation was made, and that
@@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         """
         local_user_id = "@test:test"
         remote_user_id = "@test:other"
-        request_body = {"device_keys": {remote_user_id: []}}
+        request_body: JsonDict = {"device_keys": {remote_user_id: []}}
 
         response_devices = [
             {
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e8b4e39d1a..89078fc637 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List
+from typing import List, cast
 from unittest import TestCase
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
 from synapse.api.room_versions import RoomVersions
@@ -23,7 +25,9 @@ from synapse.federation.federation_base import event_from_pdu_json
 from synapse.logging.context import LoggingContext, run_in_background
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
 from synapse.types import create_requester
+from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -42,7 +46,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastores().main
@@ -50,7 +54,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self._event_auth_handler = hs.get_event_auth_handler()
         return hs
 
-    def test_exchange_revoked_invite(self):
+    def test_exchange_revoked_invite(self) -> None:
         user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
 
@@ -96,7 +100,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
         self.assertEqual(failure.msg, "You are not invited to this room.")
 
-    def test_rejected_message_event_state(self):
+    def test_rejected_message_event_state(self) -> None:
         """
         Check that we store the state group correctly for rejected non-state events.
 
@@ -126,7 +130,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
                 "content": {},
                 "room_id": room_id,
                 "sender": "@yetanotheruser:" + OTHER_SERVER,
-                "depth": join_event["depth"] + 1,
+                "depth": cast(int, join_event["depth"]) + 1,
                 "prev_events": [join_event.event_id],
                 "auth_events": [],
                 "origin_server_ts": self.clock.time_msec(),
@@ -149,7 +153,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
-    def test_rejected_state_event_state(self):
+    def test_rejected_state_event_state(self) -> None:
         """
         Check that we store the state group correctly for rejected state events.
 
@@ -180,7 +184,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
                 "content": {},
                 "room_id": room_id,
                 "sender": "@yetanotheruser:" + OTHER_SERVER,
-                "depth": join_event["depth"] + 1,
+                "depth": cast(int, join_event["depth"]) + 1,
                 "prev_events": [join_event.event_id],
                 "auth_events": [],
                 "origin_server_ts": self.clock.time_msec(),
@@ -203,7 +207,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
-    def test_backfill_with_many_backward_extremities(self):
+    def test_backfill_with_many_backward_extremities(self) -> None:
         """
         Check that we can backfill with many backward extremities.
         The goal is to make sure that when we only use a portion
@@ -262,7 +266,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             )
         self.get_success(d)
 
-    def test_backfill_floating_outlier_membership_auth(self):
+    def test_backfill_floating_outlier_membership_auth(self) -> None:
         """
         As the local homeserver, check that we can properly process a federated
         event from the OTHER_SERVER with auth_events that include a floating
@@ -377,7 +381,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
                 for ae in auth_events
             ]
 
-        self.handler.federation_client.get_event_auth = get_event_auth
+        self.handler.federation_client.get_event_auth = get_event_auth  # type: ignore[assignment]
 
         with LoggingContext("receive_pdu"):
             # Fake the OTHER_SERVER federating the message event over to our local homeserver
@@ -397,7 +401,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
     @unittest.override_config(
         {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
     )
-    def test_invite_by_user_ratelimit(self):
+    def test_invite_by_user_ratelimit(self) -> None:
         """Tests that invites from federation to a particular user are
         actually rate-limited.
         """
@@ -446,7 +450,9 @@ class FederationTestCase(unittest.HomeserverTestCase):
             exc=LimitExceededError,
         )
 
-    def _build_and_send_join_event(self, other_server, other_user, room_id):
+    def _build_and_send_join_event(
+        self, other_server: str, other_user: str, room_id: str
+    ) -> EventBase:
         join_event = self.get_success(
             self.handler.on_make_join_request(other_server, room_id, other_user)
         )
@@ -469,7 +475,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
 
 class EventFromPduTestCase(TestCase):
-    def test_valid_json(self):
+    def test_valid_json(self) -> None:
         """Valid JSON should be turned into an event."""
         ev = event_from_pdu_json(
             {
@@ -487,7 +493,7 @@ class EventFromPduTestCase(TestCase):
 
         self.assertIsInstance(ev, EventBase)
 
-    def test_invalid_numbers(self):
+    def test_invalid_numbers(self) -> None:
         """Invalid values for an integer should be rejected, all floats should be rejected."""
         for value in [
             -(2 ** 53),
@@ -512,7 +518,7 @@ class EventFromPduTestCase(TestCase):
                     RoomVersions.V6,
                 )
 
-    def test_invalid_nested(self):
+    def test_invalid_nested(self) -> None:
         """List and dictionaries are recursively searched."""
         with self.assertRaises(SynapseError):
             event_from_pdu_json(
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e8418b6638..014815db6e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,18 @@
 # limitations under the License.
 import json
 import os
+from typing import Any, Dict
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
 import pymacaroons
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 from synapse.util.macaroons import get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-async def get_json(url):
+async def get_json(url: str) -> JsonDict:
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
         # Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
     elif url == JWKS_URI:
         return {"keys": []}
 
+    return {}
+
 
 def _key_file_path() -> str:
     """path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
     if not HAS_OIDC:
         skip = "requires OIDC"
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
         return config
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.http_client = Mock(spec=["get_json"])
         self.http_client.get_json.side_effect = get_json
         self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
-        sso_handler.render_error = self.render_error
+        sso_handler.render_error = self.render_error  # type: ignore[assignment]
 
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return args
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_config(self):
+    def test_config(self) -> None:
         """Basic config correctly sets up the callback URL and client auth correctly."""
         self.assertEqual(self.provider._callback_url, CALLBACK_URL)
         self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
         self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
-    def test_discovery(self):
+    def test_discovery(self) -> None:
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
         metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
-    def test_no_discovery(self):
+    def test_no_discovery(self) -> None:
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
-    def test_load_jwks(self):
+    def test_load_jwks(self) -> None:
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_validate_config(self):
+    def test_validate_config(self) -> None:
         """Provider metadatas are extensively validated."""
         h = self.provider
 
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             force_load_metadata()
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
-    def test_skip_verification(self):
+    def test_skip_verification(self) -> None:
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
             get_awaitable_result(self.provider.load_metadata())
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_redirect_request(self):
+    def test_redirect_request(self) -> None:
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["cookies"])
         req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(redirect, "http://client/redirect")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback_error(self):
+    def test_callback_error(self) -> None:
         """Errors from the provider returned in the callback are displayed."""
         request = Mock(args={})
         request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_client", "some description")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback(self):
+    def test_callback(self) -> None:
         """Code callback works and display errors if something went wrong.
 
         A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.provider._exchange_code = simple_async_mock(return_value=token)
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.provider._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "type": "bearer",
             "access_token": "access_token",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         id_token = {
             "sid": "abcdefgh",
         }
-        self.provider._parse_id_token = simple_async_mock(return_value=id_token)
-        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=id_token)  # type: ignore[assignment]
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
         auth_handler.complete_sso_login.reset_mock()
         self.provider._fetch_userinfo.reset_mock()
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
         from synapse.handlers.oidc import OidcError
 
-        self.provider._exchange_code = simple_async_mock(
+        self.provider._exchange_code = simple_async_mock(  # type: ignore[assignment]
             raises=OidcError("invalid_request")
         )
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback_session(self):
+    def test_callback_session(self) -> None:
         """The callback verifies the session presence and validity"""
         request = Mock(spec=["args", "getCookie", "cookies"])
 
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
     )
-    def test_exchange_code(self):
+    def test_exchange_code(self) -> None:
         """Code exchange behaves correctly and handles various error scenarios."""
         token = {"type": "bearer"}
         token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_exchange_code_jwt_key(self):
+    def test_exchange_code_jwt_key(self) -> None:
         """Test that code exchange works with a JWK client secret."""
         from authlib.jose import jwt
 
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_exchange_code_no_auth(self):
+    def test_exchange_code_no_auth(self) -> None:
         """Test that code exchange works with no client secret."""
         token = {"type": "bearer"}
         self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_extra_attributes(self):
+    def test_extra_attributes(self) -> None:
         """
         Login while using a mapping provider that implements get_extra_attributes.
         """
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "foo",
             "phone": "1234567",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_map_userinfo_to_user(self):
+    def test_map_userinfo_to_user(self) -> None:
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
-        userinfo = {
+        userinfo: dict = {
             "sub": "test_user",
             "username": "test_user",
         }
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
-    def test_map_userinfo_to_existing_user(self):
+    def test_map_userinfo_to_existing_user(self) -> None:
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
         store = self.hs.get_datastores().main
         user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_map_userinfo_to_invalid_localpart(self):
+    def test_map_userinfo_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
         self.get_success(
             _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_map_userinfo_to_user_retries(self):
+    def test_map_userinfo_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_empty_localpart(self):
+    def test_empty_localpart(self) -> None:
         """Attempts to map onto an empty localpart should be rejected."""
         userinfo = {
             "sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_null_localpart(self):
+    def test_null_localpart(self) -> None:
         """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
         userinfo = {
             "sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_attribute_requirements(self):
+    def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the OIDC userinfo response."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_attribute_requirements_contains(self):
+    def test_attribute_requirements_contains(self) -> None:
         """Test that auth succeeds if userinfo attribute CONTAINS required value"""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_attribute_requirements_mismatch(self):
+    def test_attribute_requirements_mismatch(self) -> None:
         """
         Test that auth fails if attributes exist but don't match,
         or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": "not_foobar" attribute should fail
-        userinfo = {
+        userinfo: dict = {
             "sub": "tester",
             "username": "tester",
             "test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
 
     handler = hs.get_oidc_handler()
     provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
-    provider._parse_id_token = simple_async_mock(return_value=userinfo)
-    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})  # type: ignore[assignment]
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
 
     state = "state"
     session = handler._token_generator.generate_oidc_session_token(
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 49d832de81..d401fda938 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -124,7 +124,6 @@ class PasswordCustomAuthProvider:
                 ("m.login.password", ("password",)): self.check_auth,
             }
         )
-        pass
 
     def check_auth(self, *args):
         return mock_password_provider.check_auth(*args)
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 6ddec9ecf1..b2ed9cbe37 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -331,11 +331,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase):
 
         # Extract presence update user ID and state information into lists of tuples
         db_presence_states = [(ps[0], ps[1]) for _, ps in db_presence_states[0]]
-        presence_states = [(ps.user_id, ps.state) for ps in presence_states]
+        presence_states_compare = [(ps.user_id, ps.state) for ps in presence_states]
 
         # Compare what we put into the storage with what we got out.
         # They should be identical.
-        self.assertEqual(presence_states, db_presence_states)
+        self.assertEqual(presence_states_compare, db_presence_states)
 
 
 class PresenceTimeoutTestCase(unittest.TestCase):
@@ -357,6 +357,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.UNAVAILABLE)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -380,6 +381,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.BUSY)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -399,6 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.OFFLINE)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -420,6 +423,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.ONLINE)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -477,6 +481,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         )
 
         self.assertIsNotNone(new_state)
+        assert new_state is not None
         self.assertEqual(new_state.state, PresenceState.OFFLINE)
         self.assertEqual(new_state.status_msg, status_msg)
 
@@ -653,13 +658,13 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
         self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
 
     def _set_presencestate_with_status_msg(
-        self, user_id: str, state: PresenceState, status_msg: Optional[str]
+        self, user_id: str, state: str, status_msg: Optional[str]
     ):
         """Set a PresenceState and status_msg and check the result.
 
         Args:
             user_id: User for that the status is to be set.
-            PresenceState: The new PresenceState.
+            state: The new PresenceState.
             status_msg: Status message that is to be set.
         """
         self.get_success(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 972cbac6e4..1ec105c373 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -11,14 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict
+from typing import Any, Awaitable, Callable, Dict
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.types
 from synapse.api.errors import AuthError, SynapseError
 from synapse.rest import admin
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
     servlets = [admin.register_servlets]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.mock_federation = Mock()
         self.mock_registry = Mock()
 
-        self.query_handlers = {}
+        self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
 
-        def register_query_handler(query_type, handler):
+        def register_query_handler(
+            query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+        ) -> None:
             self.query_handlers[query_type] = handler
 
         self.mock_registry.register_query_handler = register_query_handler
@@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         return hs
 
-    def prepare(self, reactor, clock, hs: HomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
         self.frank = UserID.from_string("@1234abcd:test")
@@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         self.handler = hs.get_profile_handler()
 
-    def test_get_my_name(self):
+    def test_get_my_name(self) -> None:
         self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
@@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual("Frank", displayname)
 
-    def test_set_my_name(self):
+    def test_set_my_name(self) -> None:
         self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             self.get_success(self.store.get_profile_displayname(self.frank.localpart))
         )
 
-    def test_set_my_name_if_disabled(self):
+    def test_set_my_name_if_disabled(self) -> None:
         self.hs.config.registration.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
@@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             SynapseError,
         )
 
-    def test_set_my_name_noauth(self):
+    def test_set_my_name_noauth(self) -> None:
         self.get_failure(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             AuthError,
         )
 
-    def test_get_other_name(self):
+    def test_get_other_name(self) -> None:
         self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
         )
@@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             ignore_backoff=True,
         )
 
-    def test_incoming_fed_query(self):
+    def test_incoming_fed_query(self) -> None:
         self.get_success(self.store.create_profile("caroline"))
         self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
 
@@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual({"displayname": "Caroline"}, response)
 
-    def test_get_my_avatar(self):
+    def test_get_my_avatar(self) -> None:
         self.get_success(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
@@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual("http://my.server/me.png", avatar_url)
 
-    def test_set_my_avatar(self):
+    def test_set_my_avatar(self) -> None:
         self.get_success(
             self.handler.set_avatar_url(
                 self.frank,
@@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
         )
 
-    def test_set_my_avatar_if_disabled(self):
+    def test_set_my_avatar_if_disabled(self) -> None:
         self.hs.config.registration.enable_set_avatar_url = False
 
         # Setting displayname for the first time is allowed
@@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             SynapseError,
         )
 
-    def test_avatar_constraints_no_config(self):
+    def test_avatar_constraints_no_config(self) -> None:
         """Tests that the method to check an avatar against configured constraints skips
         all of its check if no constraint is configured.
         """
@@ -263,7 +268,13 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertTrue(res)
 
     @unittest.override_config({"max_avatar_size": 50})
-    def test_avatar_constraints_missing(self):
+    def test_avatar_constraints_allow_empty_avatar_url(self) -> None:
+        """An empty avatar is always permitted."""
+        res = self.get_success(self.handler.check_avatar_size_and_mime_type(""))
+        self.assertTrue(res)
+
+    @unittest.override_config({"max_avatar_size": 50})
+    def test_avatar_constraints_missing(self) -> None:
         """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
         be found.
         """
@@ -273,7 +284,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertFalse(res)
 
     @unittest.override_config({"max_avatar_size": 50})
-    def test_avatar_constraints_file_size(self):
+    def test_avatar_constraints_file_size(self) -> None:
         """Tests that a file that's above the allowed file size is forbidden but one
         that's below it is allowed.
         """
@@ -295,7 +306,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertFalse(res)
 
     @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
-    def test_avatar_constraint_mime_type(self):
+    def test_avatar_constraint_mime_type(self) -> None:
         """Tests that a file with an unauthorised MIME type is forbidden but one with
         an authorised content type is allowed.
         """
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index cff07a8973..d37292ce13 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -172,6 +172,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         result_room_ids = []
         result_children_ids = []
         for result_room in result["rooms"]:
+            # Ensure federation results are not leaking over the client-server API.
+            self.assertNotIn("allowed_room_ids", result_room)
+
             result_room_ids.append(result_room["room_id"])
             result_children_ids.append(
                 [
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 23941abed8..8d4404eda1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,12 +12,16 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 from unittest.mock import Mock
 
 import attr
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.errors import RedirectException
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
 
 
 class SamlHandlerTestCase(HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
-        saml_config = {
+        saml_config: Dict[str, Any] = {
             "sp_config": {"metadata": {}},
             # Disable grandfathering.
             "grandfathered_mxid_source_attribute": None,
@@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         return config
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver()
 
         self.handler = hs.get_saml_handler()
@@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
     elif not has_xmlsec1:
         skip = "Requires xmlsec1"
 
-    def test_map_saml_response_to_user(self):
+    def test_map_saml_response_to_user(self) -> None:
         """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
 
         # stub out the auth handler
@@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
-    def test_map_saml_response_to_existing_user(self):
+    def test_map_saml_response_to_existing_user(self) -> None:
         """Existing users can log in with SAML account."""
         store = self.hs.get_datastores().main
         self.get_success(
@@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             auth_provider_session_id=None,
         )
 
-    def test_map_saml_response_to_invalid_localpart(self):
+    def test_map_saml_response_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
 
         # stub out the auth handler
@@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         auth_handler.complete_sso_login.assert_not_called()
 
-    def test_map_saml_response_to_user_retries(self):
+    def test_map_saml_response_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
 
         # stub out the auth handler and error renderer
@@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_map_saml_response_redirect(self):
+    def test_map_saml_response_redirect(self) -> None:
         """Test a mapping provider that raises a RedirectException"""
 
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             },
         }
     )
-    def test_attribute_requirements(self):
+    def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the SAML response."""
 
         # stub out the auth handler
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f91a80b9fa..ffd5c4cb93 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -18,11 +18,14 @@ from typing import Dict
 from unittest.mock import ANY, Mock, call
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 from synapse.api.errors import AuthError
 from synapse.federation.transport.server import TransportLayerServer
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -42,7 +45,9 @@ ROOM_ID = "a-room"
 OTHER_ROOM_ID = "another-room"
 
 
-def _expect_edu_transaction(edu_type, content, origin="test"):
+def _expect_edu_transaction(
+    edu_type: str, content: JsonDict, origin: str = "test"
+) -> JsonDict:
     return {
         "origin": origin,
         "origin_server_ts": 1000000,
@@ -51,12 +56,12 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
     }
 
 
-def _make_edu_transaction_json(edu_type, content):
+def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
     return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
 
 
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
         mock_keyring = Mock(spec=["verify_json_for_server"])
@@ -83,7 +88,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         d["/_matrix/federation"] = TransportLayerServer(self.hs)
         return d
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         mock_notifier = hs.get_notifier()
         self.on_new_event = mock_notifier.on_new_event
 
@@ -111,24 +116,24 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.room_members = []
 
-        async def check_user_in_room(room_id, user_id):
+        async def check_user_in_room(room_id: str, user_id: str) -> None:
             if user_id not in [u.to_string() for u in self.room_members]:
                 raise AuthError(401, "User is not in the room")
             return None
 
         hs.get_auth().check_user_in_room = check_user_in_room
 
-        async def check_host_in_room(room_id, server_name):
+        async def check_host_in_room(room_id: str, server_name: str) -> bool:
             return room_id == ROOM_ID
 
         hs.get_event_auth_handler().check_host_in_room = check_host_in_room
 
-        def get_joined_hosts_for_room(room_id):
+        def get_joined_hosts_for_room(room_id: str):
             return {member.domain for member in self.room_members}
 
         self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
 
-        async def get_users_in_room(room_id):
+        async def get_users_in_room(room_id: str):
             return {str(u) for u in self.room_members}
 
         self.datastore.get_users_in_room = get_users_in_room
@@ -153,7 +158,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             lambda *args, **kwargs: make_awaitable(None)
         )
 
-    def test_started_typing_local(self):
+    def test_started_typing_local(self) -> None:
         self.room_members = [U_APPLE, U_BANANA]
 
         self.assertEqual(self.event_source.get_current_key(), 0)
@@ -187,7 +192,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         )
 
     @override_config({"send_federation": True})
-    def test_started_typing_remote_send(self):
+    def test_started_typing_remote_send(self) -> None:
         self.room_members = [U_APPLE, U_ONION]
 
         self.get_success(
@@ -217,7 +222,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             try_trailing_slash_on_400=True,
         )
 
-    def test_started_typing_remote_recv(self):
+    def test_started_typing_remote_recv(self) -> None:
         self.room_members = [U_APPLE, U_ONION]
 
         self.assertEqual(self.event_source.get_current_key(), 0)
@@ -256,7 +261,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             ],
         )
 
-    def test_started_typing_remote_recv_not_in_room(self):
+    def test_started_typing_remote_recv_not_in_room(self) -> None:
         self.room_members = [U_APPLE, U_ONION]
 
         self.assertEqual(self.event_source.get_current_key(), 0)
@@ -292,7 +297,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(events[1], 0)
 
     @override_config({"send_federation": True})
-    def test_stopped_typing(self):
+    def test_stopped_typing(self) -> None:
         self.room_members = [U_APPLE, U_BANANA, U_ONION]
 
         # Gut-wrenching
@@ -343,7 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
         )
 
-    def test_typing_timeout(self):
+    def test_typing_timeout(self) -> None:
         self.room_members = [U_APPLE, U_BANANA]
 
         self.assertEqual(self.event_source.get_current_key(), 0)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index c3f20f9692..10dd94b549 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -86,6 +86,16 @@ class ModuleApiTestCase(HomeserverTestCase):
         displayname = self.get_success(self.store.get_profile_displayname("bob"))
         self.assertEqual(displayname, "Bobberino")
 
+    def test_can_register_admin_user(self):
+        user_id = self.get_success(
+            self.register_user(
+                "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
+            )
+        )
+        found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        self.assertEqual(found_user.user_id.to_string(), user_id)
+        self.assertIdentical(found_user.is_admin, True)
+
     def test_get_userinfo_by_id(self):
         user_id = self.register_user("alice", "1234")
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index c284beb37c..ba158f5d93 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,14 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, List, Optional, Tuple
 from unittest.mock import Mock
 
 from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.logging.context import make_deferred_yieldable
 from synapse.push import PusherConfigException
-from synapse.rest.client import login, receipts, room
+from synapse.rest.client import login, push_rule, receipts, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -29,22 +34,23 @@ class HTTPPusherTests(HomeserverTestCase):
         room.register_servlets,
         login.register_servlets,
         receipts.register_servlets,
+        push_rule.register_servlets,
     ]
     user_id = True
     hijack_auth = False
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["start_pushers"] = True
         return config
 
-    def make_homeserver(self, reactor, clock):
-        self.push_attempts = []
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.push_attempts: List[Tuple[Deferred, str, dict]] = []
 
         m = Mock()
 
         def post_json_get_json(url, body):
-            d = Deferred()
+            d: Deferred = Deferred()
             self.push_attempts.append((d, url, body))
             return make_deferred_yieldable(d)
 
@@ -54,7 +60,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         return hs
 
-    def test_invalid_configuration(self):
+    def test_invalid_configuration(self) -> None:
         """Invalid push configurations should be rejected."""
         # Register the user who gets notified
         user_id = self.register_user("user", "pass")
@@ -66,7 +72,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         token_id = user_tuple.token_id
 
-        def test_data(data):
+        def test_data(data: Optional[JsonDict]) -> None:
             self.get_failure(
                 self.hs.get_pusherpool().add_pusher(
                     user_id=user_id,
@@ -93,7 +99,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # A url with an incorrect path isn't accepted.
         test_data({"url": "http://example.com/foo"})
 
-    def test_sends_http(self):
+    def test_sends_http(self) -> None:
         """
         The HTTP pusher will send pushes for each message to a HTTP endpoint
         when configured to do so.
@@ -198,7 +204,7 @@ class HTTPPusherTests(HomeserverTestCase):
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
 
-    def test_sends_high_priority_for_encrypted(self):
+    def test_sends_high_priority_for_encrypted(self) -> None:
         """
         The HTTP pusher will send pushes at high priority if they correspond
         to an encrypted message.
@@ -319,7 +325,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
 
-    def test_sends_high_priority_for_one_to_one_only(self):
+    def test_sends_high_priority_for_one_to_one_only(self) -> None:
         """
         The HTTP pusher will send pushes at high priority if they correspond
         to a message in a one-to-one room.
@@ -402,7 +408,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
 
-    def test_sends_high_priority_for_mention(self):
+    def test_sends_high_priority_for_mention(self) -> None:
         """
         The HTTP pusher will send pushes at high priority if they correspond
         to a message containing the user's display name.
@@ -478,7 +484,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
 
-    def test_sends_high_priority_for_atroom(self):
+    def test_sends_high_priority_for_atroom(self) -> None:
         """
         The HTTP pusher will send pushes at high priority if they correspond
         to a message that contains @room.
@@ -561,7 +567,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # check that this is low-priority
         self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
 
-    def test_push_unread_count_group_by_room(self):
+    def test_push_unread_count_group_by_room(self) -> None:
         """
         The HTTP pusher will group unread count by number of unread rooms.
         """
@@ -574,7 +580,7 @@ class HTTPPusherTests(HomeserverTestCase):
         self._check_push_attempt(6, 1)
 
     @override_config({"push": {"group_unread_count_by_room": False}})
-    def test_push_unread_count_message_count(self):
+    def test_push_unread_count_message_count(self) -> None:
         """
         The HTTP pusher will send the total unread message count.
         """
@@ -587,7 +593,7 @@ class HTTPPusherTests(HomeserverTestCase):
         # last read receipt
         self._check_push_attempt(6, 3)
 
-    def _test_push_unread_count(self):
+    def _test_push_unread_count(self) -> None:
         """
         Tests that the correct unread count appears in sent push notifications
 
@@ -679,7 +685,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         self.helper.send(room_id, body="HELLO???", tok=other_access_token)
 
-    def _advance_time_and_make_push_succeed(self, expected_push_attempts):
+    def _advance_time_and_make_push_succeed(self, expected_push_attempts: int) -> None:
         self.pump()
         self.push_attempts[expected_push_attempts - 1][0].callback({})
 
@@ -706,7 +712,9 @@ class HTTPPusherTests(HomeserverTestCase):
             expected_unread_count_last_push,
         )
 
-    def _send_read_request(self, access_token, message_event_id, room_id):
+    def _send_read_request(
+        self, access_token: str, message_event_id: str, room_id: str
+    ) -> None:
         # Now set the user's read receipt position to the first event
         #
         # This will actually trigger a new notification to be sent out so that
@@ -719,3 +727,67 @@ class HTTPPusherTests(HomeserverTestCase):
             access_token=access_token,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
+
+    def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
+        user_id = self.register_user(username, "pass")
+        access_token = self.login(username, "pass")
+
+        # Register the pusher
+        user_tuple = self.get_success(
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple.token_id
+
+        self.get_success(
+            self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="http",
+                app_id="m.http",
+                app_display_name="HTTP Push Notifications",
+                device_display_name="pushy push",
+                pushkey="a@example.com",
+                lang=None,
+                data={"url": "http://example.com/_matrix/push/v1/notify"},
+            )
+        )
+
+        return user_id, access_token
+
+    def test_dont_notify_rule_overrides_message(self) -> None:
+        """
+        The override push rule will suppress notification
+        """
+
+        user_id, access_token = self._make_user_with_pusher("user")
+        other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+        # Create a room
+        room = self.helper.create_room_as(user_id, tok=access_token)
+
+        # Disable user notifications for this room -> user
+        body = {
+            "conditions": [{"kind": "event_match", "key": "room_id", "pattern": room}],
+            "actions": ["dont_notify"],
+        }
+        channel = self.make_request(
+            "PUT",
+            "/pushrules/global/override/best.friend",
+            body,
+            access_token=access_token,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Check we start with no pushes
+        self.assertEqual(len(self.push_attempts), 0)
+
+        # The other user joins
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # The other user sends a message (ignored by dont_notify push rule set above)
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 0)
+
+        # The user sends a message back (sends a notification)
+        self.helper.send(room, body="Hello", tok=access_token)
+        self.assertEqual(len(self.push_attempts), 1)
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 3849beb9d6..5dba187076 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
+from typing import Dict, Optional, Union
 
 import frozendict
 
@@ -20,12 +20,13 @@ from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent
 from synapse.push import push_rule_evaluator
 from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
+from synapse.types import JsonDict
 
 from tests import unittest
 
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
-    def _get_evaluator(self, content):
+    def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluatorForEvent:
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -39,12 +40,12 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         )
         room_member_count = 0
         sender_power_level = 0
-        power_levels = {}
+        power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
         return PushRuleEvaluatorForEvent(
             event, room_member_count, sender_power_level, power_levels
         )
 
-    def test_display_name(self):
+    def test_display_name(self) -> None:
         """Check for a matching display name in the body of the event."""
         evaluator = self._get_evaluator({"body": "foo bar baz"})
 
@@ -71,20 +72,20 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
 
     def _assert_matches(
-        self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+        self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
     ) -> None:
         evaluator = self._get_evaluator(content)
         self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
 
     def _assert_not_matches(
-        self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
+        self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
     ) -> None:
         evaluator = self._get_evaluator(content)
         self.assertFalse(
             evaluator.matches(condition, "@user:test", "display_name"), msg
         )
 
-    def test_event_match_body(self):
+    def test_event_match_body(self) -> None:
         """Check that event_match conditions on content.body work as expected"""
 
         # if the key is `content.body`, the pattern matches substrings.
@@ -165,7 +166,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             r"? after \ should match any character",
         )
 
-    def test_event_match_non_body(self):
+    def test_event_match_non_body(self) -> None:
         """Check that event_match conditions on other keys work as expected"""
 
         # if the key is anything other than 'content.body', the pattern must match the
@@ -241,7 +242,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             "pattern should not match before a newline",
         )
 
-    def test_no_body(self):
+    def test_no_body(self) -> None:
         """Not having a body shouldn't break the evaluator."""
         evaluator = self._get_evaluator({})
 
@@ -250,7 +251,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         }
         self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
 
-    def test_invalid_body(self):
+    def test_invalid_body(self) -> None:
         """A non-string body should not break the evaluator."""
         condition = {
             "kind": "contains_display_name",
@@ -260,7 +261,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             evaluator = self._get_evaluator({"body": body})
             self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
 
-    def test_tweaks_for_actions(self):
+    def test_tweaks_for_actions(self) -> None:
         """
         This tests the behaviour of tweaks_for_actions.
         """
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index a7a05a564f..9c5df266bd 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -251,7 +251,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
                 self.connect_any_redis_attempts,
             )
 
-            self.hs.get_tcp_replication().start_replication(self.hs)
+            self.hs.get_replication_command_handler().start_replication(self.hs)
 
         # When we see a connection attempt to the master replication listener we
         # automatically set up the connection. This is so that tests don't
@@ -375,7 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         )
 
         if worker_hs.config.redis.redis_enabled:
-            worker_hs.get_tcp_replication().start_replication(worker_hs)
+            worker_hs.get_replication_command_handler().start_replication(worker_hs)
 
         return worker_hs
 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index f9d5da723c..641a94133b 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -420,7 +420,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         # Manually send an old RDATA command, which should get dropped. This
         # re-uses the row from above, but with an earlier stream token.
-        self.hs.get_tcp_replication().send_command(
+        self.hs.get_replication_command_handler().send_command(
             RdataCommand("events", "master", 1, row)
         )
 
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 3ff5afc6e5..9a229dd23f 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -118,7 +118,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
 
         # Reset the typing handler
         self.hs.get_replication_streams()["typing"].last_token = 0
-        self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+        self.hs.get_replication_command_handler()._streams["typing"].last_token = 0
         typing._latest_room_serial = 0
         typing._typing_stream_change_cache = StreamChangeCache(
             "TypingStreamChangeCache", typing._latest_room_serial
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 1b6a4bf4b0..26b8bd512a 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -48,7 +48,7 @@ class FederationAckTestCase(HomeserverTestCase):
         transport, rather than assuming that the implementation has a
         ReplicationCommandHandler.
         """
-        rch = self.hs.get_tcp_replication()
+        rch = self.hs.get_replication_command_handler()
 
         # wire up the ReplicationCommandHandler to a mock connection, which needs
         # to implement IReplicationConnection. (Note that Mock doesn't understand
diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py
index fb36aa9940..6cf56b1e35 100644
--- a/tests/rest/admin/test_background_updates.py
+++ b/tests/rest/admin/test_background_updates.py
@@ -39,6 +39,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
+        self.updater = BackgroundUpdater(hs, self.store.db_pool)
 
     @parameterized.expand(
         [
@@ -135,10 +136,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
         """Test the status API works with a background update."""
 
         # Create a new background update
-
         self._register_bg_update()
 
         self.store.db_pool.updates.start_doing_background_updates()
+
         self.reactor.pump([1.0, 1.0, 1.0])
 
         channel = self.make_request(
@@ -155,10 +156,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "current_updates": {
                     "master": {
                         "name": "test_update",
-                        "average_items_per_ms": 0.001,
+                        "average_items_per_ms": 0.1,
                         "total_duration_ms": 1000.0,
                         "total_item_count": (
-                            BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+                            self.updater.default_background_batch_size
                         ),
                     }
                 },
@@ -210,10 +211,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "current_updates": {
                     "master": {
                         "name": "test_update",
-                        "average_items_per_ms": 0.001,
+                        "average_items_per_ms": 0.1,
                         "total_duration_ms": 1000.0,
                         "total_item_count": (
-                            BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+                            self.updater.default_background_batch_size
                         ),
                     }
                 },
@@ -239,10 +240,10 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "current_updates": {
                     "master": {
                         "name": "test_update",
-                        "average_items_per_ms": 0.001,
+                        "average_items_per_ms": 0.1,
                         "total_duration_ms": 1000.0,
                         "total_item_count": (
-                            BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
+                            self.updater.default_background_batch_size
                         ),
                     }
                 },
@@ -278,11 +279,9 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
                 "current_updates": {
                     "master": {
                         "name": "test_update",
-                        "average_items_per_ms": 0.001,
+                        "average_items_per_ms": 0.05263157894736842,
                         "total_duration_ms": 2000.0,
-                        "total_item_count": (
-                            2 * BackgroundUpdater.MINIMUM_BACKGROUND_BATCH_SIZE
-                        ),
+                        "total_item_count": (110),
                     }
                 },
                 "enabled": True,
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index a60ea0a563..bef911d5df 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1050,6 +1050,25 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
 
         self._is_erased("@user:test", True)
 
+    @override_config({"max_avatar_size": 1234})
+    def test_deactivate_user_erase_true_avatar_nonnull_but_empty(self) -> None:
+        """Check we can erase a user whose avatar is the empty string.
+
+        Reproduces #12257.
+        """
+        # Patch `self.other_user` to have an empty string as their avatar.
+        self.get_success(self.store.set_profile_avatar_url("user", ""))
+
+        # Check we can still erase them.
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"erase": True},
+        )
+        self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
+        self._is_erased("@user:test", True)
+
     def test_deactivate_user_erase_false(self) -> None:
         """
         Test deactivating a user and set `erase` to `false`
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index def836054d..27946febff 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -31,7 +31,7 @@ from synapse.rest import admin
 from synapse.rest.client import account, login, register, room
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 
 from tests import unittest
@@ -1222,6 +1222,62 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
             expected_failures=[users[2]],
         )
 
+    @unittest.override_config(
+        {
+            "use_account_validity_in_account_status": True,
+        }
+    )
+    def test_no_account_validity(self) -> None:
+        """Tests that if we decide to include account validity in the response but no
+        account validity 'is_user_expired' callback is provided, we default to marking all
+        users as not expired.
+        """
+        user = self.register_user("someuser", "password")
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": False,
+                    "org.matrix.expired": False,
+                },
+            },
+            expected_failures=[],
+        )
+
+    @unittest.override_config(
+        {
+            "use_account_validity_in_account_status": True,
+        }
+    )
+    def test_account_validity_expired(self) -> None:
+        """Test that if we decide to include account validity in the response and the user
+        is expired, we return the correct info.
+        """
+        user = self.register_user("someuser", "password")
+
+        async def is_expired(user_id: str) -> bool:
+            # We can't blindly say everyone is expired, otherwise the request to get the
+            # account status will fail.
+            return UserID.from_string(user_id).localpart == "someuser"
+
+        self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
+            is_expired
+        )
+
+        self._test_status(
+            users=[user],
+            expected_statuses={
+                user: {
+                    "exists": True,
+                    "deactivated": False,
+                    "org.matrix.expired": True,
+                },
+            },
+            expected_failures=[],
+        )
+
     def _test_status(
         self,
         users: Optional[List[str]],
diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 3818b7b14b..7b7d283bb6 100644
--- a/tests/rest/client/test_shared_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -14,7 +14,7 @@
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
-from synapse.rest.client import login, room, shared_rooms
+from synapse.rest.client import login, mutual_rooms, room
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -22,16 +22,16 @@ from tests import unittest
 from tests.server import FakeChannel
 
 
-class UserSharedRoomsTest(unittest.HomeserverTestCase):
+class UserMutualRoomsTest(unittest.HomeserverTestCase):
     """
-    Tests the UserSharedRoomsServlet.
+    Tests the UserMutualRoomsServlet.
     """
 
     servlets = [
         login.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
-        shared_rooms.register_servlets,
+        mutual_rooms.register_servlets,
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -43,10 +43,10 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.handler = hs.get_user_directory_handler()
 
-    def _get_shared_rooms(self, token: str, other_user: str) -> FakeChannel:
+    def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
         return self.make_request(
             "GET",
-            "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+            "/_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/%s"
             % other_user,
             access_token=token,
         )
@@ -56,14 +56,14 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         A room should show up in the shared list of rooms between two users
         if it is public.
         """
-        self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
+        self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=True)
 
     def test_shared_room_list_private(self) -> None:
         """
         A room should show up in the shared list of rooms between two users
         if it is private.
         """
-        self._check_shared_rooms_with(
+        self._check_mutual_rooms_with(
             room_one_is_public=False, room_two_is_public=False
         )
 
@@ -72,9 +72,9 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         The shared room list between two users should contain both public and private
         rooms.
         """
-        self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
+        self._check_mutual_rooms_with(room_one_is_public=True, room_two_is_public=False)
 
-    def _check_shared_rooms_with(
+    def _check_mutual_rooms_with(
         self, room_one_is_public: bool, room_two_is_public: bool
     ) -> None:
         """Checks that shared public or private rooms between two users appear in
@@ -94,7 +94,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
 
         # Check shared rooms from user1's perspective.
         # We should see the one room in common
-        channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_mutual_rooms(u1_token, u2)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 1)
         self.assertEqual(channel.json_body["joined"][0], room_id_one)
@@ -107,7 +107,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.join(room_id_two, user=u2, tok=u2_token)
 
         # Check shared rooms again. We should now see both rooms.
-        channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_mutual_rooms(u1_token, u2)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 2)
         for room_id_id in channel.json_body["joined"]:
@@ -128,7 +128,7 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2, tok=u2_token)
 
         # Assert user directory is not empty
-        channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_mutual_rooms(u1_token, u2)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 1)
         self.assertEqual(channel.json_body["joined"][0], room)
@@ -136,11 +136,11 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         self.helper.leave(room, user=u1, tok=u1_token)
 
         # Check user1's view of shared rooms with user2
-        channel = self._get_shared_rooms(u1_token, u2)
+        channel = self._get_mutual_rooms(u1_token, u2)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 0)
 
         # Check user2's view of shared rooms with user1
-        channel = self._get_shared_rooms(u2_token, u1)
+        channel = self._get_mutual_rooms(u2_token, u1)
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 709f851a38..fe97a0b3dd 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,17 +15,16 @@
 
 import itertools
 import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import AccountDataTypes, EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room, sync
 from synapse.server import HomeServer
-from synapse.storage.relations import RelationPaginationToken
-from synapse.types import JsonDict, StreamToken
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -80,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
         content: Optional[dict] = None,
         access_token: Optional[str] = None,
         parent_id: Optional[str] = None,
+        expected_response_code: int = 200,
     ) -> FakeChannel:
         """Helper function to send a relation pointing at `self.parent_id`
 
@@ -116,16 +116,60 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
             content,
             access_token=access_token,
         )
+        self.assertEqual(expected_response_code, channel.code, channel.json_body)
         return channel
 
+    def _get_related_events(self) -> List[str]:
+        """
+        Requests /relations on the parent ID and returns a list of event IDs.
+        """
+        # Request the relations of the event.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        return [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+    def _get_bundled_aggregations(self) -> JsonDict:
+        """
+        Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists.
+        """
+        # Fetch the bundled aggregations of the event.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        return channel.json_body["unsigned"].get("m.relations", {})
+
+    def _get_aggregations(self) -> List[JsonDict]:
+        """Request /aggregations on the parent ID and includes the returned chunk."""
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        return channel.json_body["chunk"]
+
+    def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+        """
+        Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+        """
+        for event in events:
+            if event["event_id"] == self.parent_id:
+                return event
+
+        raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
 
 class RelationsTestCase(BaseRelationsTestCase):
     def test_send_relation(self) -> None:
         """Tests that sending a relation works."""
-
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEqual(200, channel.code, channel.json_body)
-
         event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -152,13 +196,13 @@ class RelationsTestCase(BaseRelationsTestCase):
 
     def test_deny_invalid_event(self) -> None:
         """Test that we deny relations on non-existant events"""
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.ANNOTATION,
             EventTypes.Message,
             parent_id="foo",
             content={"body": "foo", "msgtype": "m.text"},
+            expected_response_code=400,
         )
-        self.assertEqual(400, channel.code, channel.json_body)
 
         # Unless that event is referenced from another event!
         self.get_success(
@@ -172,13 +216,12 @@ class RelationsTestCase(BaseRelationsTestCase):
                 desc="test_deny_invalid_event",
             )
         )
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.THREAD,
             EventTypes.Message,
             parent_id="foo",
             content={"body": "foo", "msgtype": "m.text"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
     def test_deny_invalid_room(self) -> None:
         """Test that we deny relations on non-existant events"""
@@ -188,18 +231,20 @@ class RelationsTestCase(BaseRelationsTestCase):
         parent_id = res["event_id"]
 
         # Attempt to send an annotation to that event.
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+        self._send_relation(
+            RelationTypes.ANNOTATION,
+            "m.reaction",
+            parent_id=parent_id,
+            key="A",
+            expected_response_code=400,
         )
-        self.assertEqual(400, channel.code, channel.json_body)
 
     def test_deny_double_react(self) -> None:
         """Test that we deny relations on membership events"""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(400, channel.code, channel.json_body)
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400
+        )
 
     def test_deny_forked_thread(self) -> None:
         """It is invalid to start a thread off a thread."""
@@ -209,386 +254,24 @@ class RelationsTestCase(BaseRelationsTestCase):
             content={"msgtype": "m.text", "body": "foo"},
             parent_id=self.parent_id,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         parent_id = channel.json_body["event_id"]
 
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.THREAD,
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo"},
             parent_id=parent_id,
-        )
-        self.assertEqual(400, channel.code, channel.json_body)
-
-    def test_basic_paginate_relations(self) -> None:
-        """Tests that calling pagination API correctly the latest relations."""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-        first_annotation_id = channel.json_body["event_id"]
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEqual(200, channel.code, channel.json_body)
-        second_annotation_id = channel.json_body["event_id"]
-
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        # We expect to get back a single pagination result, which is the latest
-        # full relation event we sent above.
-        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-        self.assert_dict(
-            {
-                "event_id": second_annotation_id,
-                "sender": self.user_id,
-                "type": "m.reaction",
-            },
-            channel.json_body["chunk"][0],
-        )
-
-        # We also expect to get the original event (the id of which is self.parent_id)
-        self.assertEqual(
-            channel.json_body["original_event"]["event_id"], self.parent_id
-        )
-
-        # Make sure next_batch has something in it that looks like it could be a
-        # valid token.
-        self.assertIsInstance(
-            channel.json_body.get("next_batch"), str, channel.json_body
-        )
-
-        # Request the relations again, but with a different direction.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        # We expect to get back a single pagination result, which is the earliest
-        # full relation event we sent above.
-        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-        self.assert_dict(
-            {
-                "event_id": first_annotation_id,
-                "sender": self.user_id,
-                "type": "m.reaction",
-            },
-            channel.json_body["chunk"][0],
-        )
-
-    def _stream_token_to_relation_token(self, token: str) -> str:
-        """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
-        room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
-        return self.get_success(
-            RelationPaginationToken(
-                topological=room_key.topological, stream=room_key.stream
-            ).to_string(self.store)
-        )
-
-    def test_repeated_paginate_relations(self) -> None:
-        """Test that if we paginate using a limit and tokens then we get the
-        expected events.
-        """
-
-        expected_event_ids = []
-        for idx in range(10):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-            expected_event_ids.append(channel.json_body["event_id"])
-
-        prev_token = ""
-        found_event_ids: List[str] = []
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
-        # Reset and try again, but convert the tokens to the legacy format.
-        prev_token = ""
-        found_event_ids = []
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
-    def test_pagination_from_sync_and_messages(self) -> None:
-        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
-        self.assertEqual(200, channel.code, channel.json_body)
-        annotation_id = channel.json_body["event_id"]
-        # Send an event after the relation events.
-        self.helper.send(self.room, body="Latest event", tok=self.user_token)
-
-        # Request /sync, limiting it such that only the latest event is returned
-        # (and not the relation).
-        filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
-        channel = self.make_request(
-            "GET", f"/sync?filter={filter}", access_token=self.user_token
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
-        sync_prev_batch = room_timeline["prev_batch"]
-        self.assertIsNotNone(sync_prev_batch)
-        # Ensure the relation event is not in the batch returned from /sync.
-        self.assertNotIn(
-            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
-        )
-
-        # Request /messages, limiting it such that only the latest event is
-        # returned (and not the relation).
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/messages?dir=b&limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        messages_end = channel.json_body["end"]
-        self.assertIsNotNone(messages_end)
-        # Ensure the relation event is not in the chunk returned from /messages.
-        self.assertNotIn(
-            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+            expected_response_code=400,
         )
 
-        # Request /relations with the pagination tokens received from both the
-        # /sync and /messages responses above, in turn.
-        #
-        # This is a tiny bit silly since the client wouldn't know the parent ID
-        # from the requests above; consider the parent ID to be known from a
-        # previous /sync.
-        for from_token in (sync_prev_batch, messages_end):
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            # The relation should be in the returned chunk.
-            self.assertIn(
-                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
-            )
-
-    def test_aggregation_pagination_groups(self) -> None:
-        """Test that we can paginate annotation groups correctly."""
-
-        # We need to create ten separate users to send each reaction.
-        access_tokens = [self.user_token, self.user2_token]
-        idx = 0
-        while len(access_tokens) < 10:
-            user_id, token = self._create_user("test" + str(idx))
-            idx += 1
-
-            self.helper.join(self.room, user=user_id, tok=token)
-            access_tokens.append(token)
-
-        idx = 0
-        sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
-        for key in itertools.chain.from_iterable(
-            itertools.repeat(key, num) for key, num in sent_groups.items()
-        ):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION,
-                "m.reaction",
-                key=key,
-                access_token=access_tokens[idx],
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            idx += 1
-            idx %= len(access_tokens)
-
-        prev_token: Optional[str] = None
-        found_groups: Dict[str, int] = {}
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            for groups in channel.json_body["chunk"]:
-                # We only expect reactions
-                self.assertEqual(groups["type"], "m.reaction", channel.json_body)
-
-                # We should only see each key once
-                self.assertNotIn(groups["key"], found_groups, channel.json_body)
-
-                found_groups[groups["key"]] = groups["count"]
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        self.assertEqual(sent_groups, found_groups)
-
-    def test_aggregation_pagination_within_group(self) -> None:
-        """Test that we can paginate within an annotation group."""
-
-        # We need to create ten separate users to send each reaction.
-        access_tokens = [self.user_token, self.user2_token]
-        idx = 0
-        while len(access_tokens) < 10:
-            user_id, token = self._create_user("test" + str(idx))
-            idx += 1
-
-            self.helper.join(self.room, user=user_id, tok=token)
-            access_tokens.append(token)
-
-        idx = 0
-        expected_event_ids = []
-        for _ in range(10):
-            channel = self._send_relation(
-                RelationTypes.ANNOTATION,
-                "m.reaction",
-                key="👍",
-                access_token=access_tokens[idx],
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-            expected_event_ids.append(channel.json_body["event_id"])
-
-            idx += 1
-
-        # Also send a different type of reaction so that we test we don't see it
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        prev_token = ""
-        found_event_ids: List[str] = []
-        encoded_key = urllib.parse.quote_plus("👍".encode())
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + prev_token
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}"
-                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
-                f"/m.reaction/{encoded_key}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
-        # Reset and try again, but convert the tokens to the legacy format.
-        prev_token = ""
-        found_event_ids = []
-        for _ in range(20):
-            from_token = ""
-            if prev_token:
-                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
-
-            channel = self.make_request(
-                "GET",
-                f"/_matrix/client/unstable/rooms/{self.room}"
-                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
-                f"/m.reaction/{encoded_key}?limit=1{from_token}",
-                access_token=self.user_token,
-            )
-            self.assertEqual(200, channel.code, channel.json_body)
-
-            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
-
-            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
-
-            next_batch = channel.json_body.get("next_batch")
-
-            self.assertNotEqual(prev_token, next_batch)
-            prev_token = next_batch
-
-            if not prev_token:
-                break
-
-        # We paginated backwards, so reverse
-        found_event_ids.reverse()
-        self.assertEqual(found_event_ids, expected_event_ids)
-
     def test_aggregation(self) -> None:
         """Test that annotations get correctly aggregated."""
 
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEqual(200, channel.code, channel.json_body)
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
 
         channel = self.make_request(
             "GET",
@@ -618,220 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
         self.assertEqual(400, channel.code, channel.json_body)
 
-    @unittest.override_config(
-        {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
-    )
-    def test_bundled_aggregations(self) -> None:
-        """
-        Test that annotations, references, and threads get correctly bundled.
-
-        Note that this doesn't test against /relations since only thread relations
-        get bundled via that API. See test_aggregation_get_event_for_thread.
-
-        See test_edit for a similar test for edits.
-        """
-        # Setup by sending a variety of relations.
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        reply_1 = channel.json_body["event_id"]
-
-        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        reply_2 = channel.json_body["event_id"]
-
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        thread_2 = channel.json_body["event_id"]
-
-        def assert_bundle(event_json: JsonDict) -> None:
-            """Assert the expected values of the bundled aggregations."""
-            relations_dict = event_json["unsigned"].get("m.relations")
-
-            # Ensure the fields are as expected.
-            self.assertCountEqual(
-                relations_dict.keys(),
-                (
-                    RelationTypes.ANNOTATION,
-                    RelationTypes.REFERENCE,
-                    RelationTypes.THREAD,
-                ),
-            )
-
-            # Check the values of each field.
-            self.assertEqual(
-                {
-                    "chunk": [
-                        {"type": "m.reaction", "key": "a", "count": 2},
-                        {"type": "m.reaction", "key": "b", "count": 1},
-                    ]
-                },
-                relations_dict[RelationTypes.ANNOTATION],
-            )
-
-            self.assertEqual(
-                {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
-                relations_dict[RelationTypes.REFERENCE],
-            )
-
-            self.assertEqual(
-                2,
-                relations_dict[RelationTypes.THREAD].get("count"),
-            )
-            self.assertTrue(
-                relations_dict[RelationTypes.THREAD].get("current_user_participated")
-            )
-            # The latest thread event has some fields that don't matter.
-            self.assert_dict(
-                {
-                    "content": {
-                        "m.relates_to": {
-                            "event_id": self.parent_id,
-                            "rel_type": RelationTypes.THREAD,
-                        }
-                    },
-                    "event_id": thread_2,
-                    "room_id": self.room,
-                    "sender": self.user_id,
-                    "type": "m.room.test",
-                    "user_id": self.user_id,
-                },
-                relations_dict[RelationTypes.THREAD].get("latest_event"),
-            )
-
-        # Request the event directly.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/event/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        assert_bundle(channel.json_body)
-
-        # Request the room messages.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/messages?dir=b",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
-
-        # Request the room context.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/context/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        assert_bundle(channel.json_body["event"])
-
-        # Request sync.
-        channel = self.make_request("GET", "/sync", access_token=self.user_token)
-        self.assertEqual(200, channel.code, channel.json_body)
-        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
-        self.assertTrue(room_timeline["limited"])
-        assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
-
-        # Request search.
-        channel = self.make_request(
-            "POST",
-            "/search",
-            # Search term matches the parent message.
-            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        chunk = [
-            result["result"]
-            for result in channel.json_body["search_categories"]["room_events"][
-                "results"
-            ]
-        ]
-        assert_bundle(self._find_event_in_chunk(chunk))
-
-    def test_aggregation_get_event_for_annotation(self) -> None:
-        """Test that annotations do not get bundled aggregations included
-        when directly requested.
-        """
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
-        annotation_id = channel.json_body["event_id"]
-
-        # Annotate the annotation.
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/event/{annotation_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
-
-    def test_aggregation_get_event_for_thread(self) -> None:
-        """Test that threads get bundled aggregations included when directly requested."""
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
-        thread_id = channel.json_body["event_id"]
-
-        # Annotate the annotation.
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/event/{thread_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(
-            channel.json_body["unsigned"].get("m.relations"),
-            {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
-            },
-        )
-
-        # It should also be included when the entire thread is requested.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(len(channel.json_body["chunk"]), 1)
-
-        thread_message = channel.json_body["chunk"][0]
-        self.assertEqual(
-            thread_message["unsigned"].get("m.relations"),
-            {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
-            },
-        )
-
-    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
     def test_ignore_invalid_room(self) -> None:
         """Test that we ignore invalid relations over federation."""
         # Create another room and send a message in it.
@@ -953,8 +422,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
         edit_event_id = channel.json_body["event_id"]
 
         def assert_bundle(event_json: JsonDict) -> None:
@@ -1030,7 +497,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         shouldn't be allowed, are correctly handled.
         """
 
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             content={
@@ -1039,7 +506,6 @@ class RelationsTestCase(BaseRelationsTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "First edit"},
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
         channel = self._send_relation(
@@ -1047,11 +513,9 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
         edit_event_id = channel.json_body["event_id"]
 
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message.WRONG_TYPE",
             content={
@@ -1060,7 +524,6 @@ class RelationsTestCase(BaseRelationsTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         channel = self.make_request(
             "GET",
@@ -1091,7 +554,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "A reply!"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         reply = channel.json_body["event_id"]
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
@@ -1101,8 +563,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
             parent_id=reply,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
-
         edit_event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -1138,7 +598,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
-    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
     def test_edit_thread(self) -> None:
         """Test that editing a thread works."""
 
@@ -1148,17 +607,15 @@ class RelationsTestCase(BaseRelationsTestCase):
             "m.room.message",
             content={"msgtype": "m.text", "body": "A threaded reply!"},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         threaded_event_id = channel.json_body["event_id"]
 
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
             parent_id=threaded_event_id,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         # Fetch the thread root, to get the bundled aggregation for the thread.
         channel = self.make_request(
@@ -1190,11 +647,10 @@ class RelationsTestCase(BaseRelationsTestCase):
                 "m.new_content": new_body,
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
         edit_event_id = channel.json_body["event_id"]
 
         # Edit the edit event.
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             content={
@@ -1204,7 +660,6 @@ class RelationsTestCase(BaseRelationsTestCase):
             },
             parent_id=edit_event_id,
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         # Request the original event.
         channel = self.make_request(
@@ -1231,7 +686,6 @@ class RelationsTestCase(BaseRelationsTestCase):
     def test_unknown_relations(self) -> None:
         """Unknown relations should be accepted."""
         channel = self._send_relation("m.relation.test", "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
         event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -1272,28 +726,15 @@ class RelationsTestCase(BaseRelationsTestCase):
         self.assertEqual(200, channel.code, channel.json_body)
         self.assertEqual(channel.json_body["chunk"], [])
 
-    def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
-        """
-        Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
-        """
-        for event in events:
-            if event["event_id"] == self.parent_id:
-                return event
-
-        raise AssertionError(f"Event {self.parent_id} not found in chunk")
-
     def test_background_update(self) -> None:
         """Test the event_arbitrary_relations background update."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEqual(200, channel.code, channel.json_body)
         annotation_event_id_good = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
-        self.assertEqual(200, channel.code, channel.json_body)
         annotation_event_id_bad = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
-        self.assertEqual(200, channel.code, channel.json_body)
         thread_event_id = channel.json_body["event_id"]
 
         # Clean-up the table as if the inserts did not happen during event creation.
@@ -1345,8 +786,638 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
 
 
+class RelationPaginationTestCase(BaseRelationsTestCase):
+    def test_basic_paginate_relations(self) -> None:
+        """Tests that calling pagination API correctly the latest relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        first_annotation_id = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+        second_annotation_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the latest
+        # full relation event we sent above.
+        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {
+                "event_id": second_annotation_id,
+                "sender": self.user_id,
+                "type": "m.reaction",
+            },
+            channel.json_body["chunk"][0],
+        )
+
+        # We also expect to get the original event (the id of which is self.parent_id)
+        self.assertEqual(
+            channel.json_body["original_event"]["event_id"], self.parent_id
+        )
+
+        # Make sure next_batch has something in it that looks like it could be a
+        # valid token.
+        self.assertIsInstance(
+            channel.json_body.get("next_batch"), str, channel.json_body
+        )
+
+        # Request the relations again, but with a different direction.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations"
+            f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the earliest
+        # full relation event we sent above.
+        self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {
+                "event_id": first_annotation_id,
+                "sender": self.user_id,
+                "type": "m.reaction",
+            },
+            channel.json_body["chunk"][0],
+        )
+
+    def test_repeated_paginate_relations(self) -> None:
+        """Test that if we paginate using a limit and tokens then we get the
+        expected events.
+        """
+
+        expected_event_ids = []
+        for idx in range(10):
+            channel = self._send_relation(
+                RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+            )
+            expected_event_ids.append(channel.json_body["event_id"])
+
+        prev_token = ""
+        found_event_ids: List[str] = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEqual(found_event_ids, expected_event_ids)
+
+    def test_pagination_from_sync_and_messages(self) -> None:
+        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+        annotation_id = channel.json_body["event_id"]
+        # Send an event after the relation events.
+        self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+        # Request /sync, limiting it such that only the latest event is returned
+        # (and not the relation).
+        filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        sync_prev_batch = room_timeline["prev_batch"]
+        self.assertIsNotNone(sync_prev_batch)
+        # Ensure the relation event is not in the batch returned from /sync.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+        )
+
+        # Request /messages, limiting it such that only the latest event is
+        # returned (and not the relation).
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b&limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        messages_end = channel.json_body["end"]
+        self.assertIsNotNone(messages_end)
+        # Ensure the relation event is not in the chunk returned from /messages.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+        )
+
+        # Request /relations with the pagination tokens received from both the
+        # /sync and /messages responses above, in turn.
+        #
+        # This is a tiny bit silly since the client wouldn't know the parent ID
+        # from the requests above; consider the parent ID to be known from a
+        # previous /sync.
+        for from_token in (sync_prev_batch, messages_end):
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            # The relation should be in the returned chunk.
+            self.assertIn(
+                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+            )
+
+    def test_aggregation_pagination_groups(self) -> None:
+        """Test that we can paginate annotation groups correctly."""
+
+        # We need to create ten separate users to send each reaction.
+        access_tokens = [self.user_token, self.user2_token]
+        idx = 0
+        while len(access_tokens) < 10:
+            user_id, token = self._create_user("test" + str(idx))
+            idx += 1
+
+            self.helper.join(self.room, user=user_id, tok=token)
+            access_tokens.append(token)
+
+        idx = 0
+        sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
+        for key in itertools.chain.from_iterable(
+            itertools.repeat(key, num) for key, num in sent_groups.items()
+        ):
+            self._send_relation(
+                RelationTypes.ANNOTATION,
+                "m.reaction",
+                key=key,
+                access_token=access_tokens[idx],
+            )
+
+            idx += 1
+            idx %= len(access_tokens)
+
+        prev_token: Optional[str] = None
+        found_groups: Dict[str, int] = {}
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            for groups in channel.json_body["chunk"]:
+                # We only expect reactions
+                self.assertEqual(groups["type"], "m.reaction", channel.json_body)
+
+                # We should only see each key once
+                self.assertNotIn(groups["key"], found_groups, channel.json_body)
+
+                found_groups[groups["key"]] = groups["count"]
+
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        self.assertEqual(sent_groups, found_groups)
+
+    def test_aggregation_pagination_within_group(self) -> None:
+        """Test that we can paginate within an annotation group."""
+
+        # We need to create ten separate users to send each reaction.
+        access_tokens = [self.user_token, self.user2_token]
+        idx = 0
+        while len(access_tokens) < 10:
+            user_id, token = self._create_user("test" + str(idx))
+            idx += 1
+
+            self.helper.join(self.room, user=user_id, tok=token)
+            access_tokens.append(token)
+
+        idx = 0
+        expected_event_ids = []
+        for _ in range(10):
+            channel = self._send_relation(
+                RelationTypes.ANNOTATION,
+                "m.reaction",
+                key="👍",
+                access_token=access_tokens[idx],
+            )
+            expected_event_ids.append(channel.json_body["event_id"])
+
+            idx += 1
+
+        # Also send a different type of reaction so that we test we don't see it
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+
+        prev_token = ""
+        found_event_ids: List[str] = []
+        encoded_key = urllib.parse.quote_plus("👍".encode())
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEqual(found_event_ids, expected_event_ids)
+
+
+class BundledAggregationsTestCase(BaseRelationsTestCase):
+    """
+    See RelationsTestCase.test_edit for a similar test for edits.
+
+    Note that this doesn't test against /relations since only thread relations
+    get bundled via that API. See test_aggregation_get_event_for_thread.
+    """
+
+    def _test_bundled_aggregations(
+        self,
+        relation_type: str,
+        assertion_callable: Callable[[JsonDict], None],
+        expected_db_txn_for_event: int,
+    ) -> None:
+        """
+        Makes requests to various endpoints which should include bundled aggregations
+        and then calls an assertion function on the bundled aggregations.
+
+        Args:
+            relation_type: The field to search for in the `m.relations` field in unsigned.
+            assertion_callable: Called with the contents of unsigned["m.relations"][relation_type]
+                for relation-specific assertions.
+            expected_db_txn_for_event: The number of database transactions which
+                are expected for a call to /event/.
+        """
+
+        def assert_bundle(event_json: JsonDict) -> None:
+            """Assert the expected values of the bundled aggregations."""
+            relations_dict = event_json["unsigned"].get("m.relations")
+
+            # Ensure the fields are as expected.
+            self.assertCountEqual(relations_dict.keys(), (relation_type,))
+            assertion_callable(relations_dict[relation_type])
+
+        # Request the event directly.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        assert_bundle(channel.json_body)
+        assert channel.resource_usage is not None
+        self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event)
+
+        # Request the room messages.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+        # Request the room context.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/context/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        assert_bundle(channel.json_body["event"])
+
+        # Request sync.
+        filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        self.assertTrue(room_timeline["limited"])
+        assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+
+        # Request search.
+        channel = self.make_request(
+            "POST",
+            "/search",
+            # Search term matches the parent message.
+            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        chunk = [
+            result["result"]
+            for result in channel.json_body["search_categories"]["room_events"][
+                "results"
+            ]
+        ]
+        assert_bundle(self._find_event_in_chunk(chunk))
+
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+    def test_annotation(self) -> None:
+        """
+        Test that annotations get correctly bundled.
+        """
+        # Setup by sending a variety of relations.
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+        )
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+
+        def assert_annotations(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(
+                {
+                    "chunk": [
+                        {"type": "m.reaction", "key": "a", "count": 2},
+                        {"type": "m.reaction", "key": "b", "count": 1},
+                    ]
+                },
+                bundled_aggregations,
+            )
+
+        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+    def test_reference(self) -> None:
+        """
+        Test that references get correctly bundled.
+        """
+        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+        reply_1 = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+        reply_2 = channel.json_body["event_id"]
+
+        def assert_annotations(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(
+                {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
+                bundled_aggregations,
+            )
+
+        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+    def test_thread(self) -> None:
+        """
+        Test that threads get correctly bundled.
+        """
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_2 = channel.json_body["event_id"]
+
+        def assert_annotations(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(2, bundled_aggregations.get("count"))
+            self.assertTrue(bundled_aggregations.get("current_user_participated"))
+            # The latest thread event has some fields that don't matter.
+            self.assert_dict(
+                {
+                    "content": {
+                        "m.relates_to": {
+                            "event_id": self.parent_id,
+                            "rel_type": RelationTypes.THREAD,
+                        }
+                    },
+                    "event_id": thread_2,
+                    "sender": self.user_id,
+                    "type": "m.room.test",
+                },
+                bundled_aggregations.get("latest_event"),
+            )
+
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+
+    def test_aggregation_get_event_for_annotation(self) -> None:
+        """Test that annotations do not get bundled aggregations included
+        when directly requested.
+        """
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        annotation_id = channel.json_body["event_id"]
+
+        # Annotate the annotation.
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
+        )
+
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{annotation_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
+
+    def test_aggregation_get_event_for_thread(self) -> None:
+        """Test that threads get bundled aggregations included when directly requested."""
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_id = channel.json_body["event_id"]
+
+        # Annotate the annotation.
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+        )
+
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{thread_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(
+            channel.json_body["unsigned"].get("m.relations"),
+            {
+                RelationTypes.ANNOTATION: {
+                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+                },
+            },
+        )
+
+        # It should also be included when the entire thread is requested.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        self.assertEqual(len(channel.json_body["chunk"]), 1)
+
+        thread_message = channel.json_body["chunk"][0]
+        self.assertEqual(
+            thread_message["unsigned"].get("m.relations"),
+            {
+                RelationTypes.ANNOTATION: {
+                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+                },
+            },
+        )
+
+    def test_bundled_aggregations_with_filter(self) -> None:
+        """
+        If "unsigned" is an omitted field (due to filtering), adding the bundled
+        aggregations should not break.
+
+        Note that the spec allows for a server to return additional fields beyond
+        what is specified.
+        """
+        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+
+        # Note that the sync filter does not include "unsigned" as a field.
+        filter = urllib.parse.quote_plus(
+            b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}'
+        )
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+
+        # Ensure the timeline is limited, find the parent event.
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        self.assertTrue(room_timeline["limited"])
+        parent_event = self._find_event_in_chunk(room_timeline["events"])
+
+        # Ensure there's bundled aggregations on it.
+        self.assertIn("unsigned", parent_event)
+        self.assertIn("m.relations", parent_event["unsigned"])
+
+
+class RelationIgnoredUserTestCase(BaseRelationsTestCase):
+    """Relations sent from an ignored user should be ignored."""
+
+    def _test_ignored_user(
+        self, allowed_event_ids: List[str], ignored_event_ids: List[str]
+    ) -> None:
+        """
+        Fetch the relations and ensure they're all there, then ignore user2, and
+        repeat.
+        """
+        # Get the relations.
+        event_ids = self._get_related_events()
+        self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids)
+
+        # Ignore user2 and re-do the requests.
+        self.get_success(
+            self.store.add_account_data_for_user(
+                self.user_id,
+                AccountDataTypes.IGNORED_USER_LIST,
+                {"ignored_users": {self.user2_id: {}}},
+            )
+        )
+
+        # Get the relations.
+        event_ids = self._get_related_events()
+        self.assertCountEqual(event_ids, allowed_event_ids)
+
+    def test_annotation(self) -> None:
+        """Annotations should ignore"""
+        # Send 2 from us, 2 from the to be ignored user.
+        allowed_event_ids = []
+        ignored_event_ids = []
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+        allowed_event_ids.append(channel.json_body["event_id"])
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
+        allowed_event_ids.append(channel.json_body["event_id"])
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION,
+            "m.reaction",
+            key="a",
+            access_token=self.user2_token,
+        )
+        ignored_event_ids.append(channel.json_body["event_id"])
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION,
+            "m.reaction",
+            key="c",
+            access_token=self.user2_token,
+        )
+        ignored_event_ids.append(channel.json_body["event_id"])
+
+        self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+    def test_reference(self) -> None:
+        """Annotations should ignore"""
+        channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+        allowed_event_ids = [channel.json_body["event_id"]]
+
+        channel = self._send_relation(
+            RelationTypes.REFERENCE, "m.room.test", access_token=self.user2_token
+        )
+        ignored_event_ids = [channel.json_body["event_id"]]
+
+        self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+    def test_thread(self) -> None:
+        """Annotations should ignore"""
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        allowed_event_ids = [channel.json_body["event_id"]]
+
+        channel = self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
+        ignored_event_ids = [channel.json_body["event_id"]]
+
+        self._test_ignored_user(allowed_event_ids, ignored_event_ids)
+
+
 class RelationRedactionTestCase(BaseRelationsTestCase):
-    """Test the behaviour of relations when the parent or child event is redacted."""
+    """
+    Test the behaviour of relations when the parent or child event is redacted.
+
+    The behaviour of each relation type is subtly different which causes the tests
+    to be a bit repetitive, they follow a naming scheme of:
+
+        test_redact_(relation|parent)_{relation_type}
+
+    The first bit of "relation" means that the event with the relation defined
+    on it (the child event) is to be redacted. A "parent" means that the target
+    of the relation (the parent event) is to be redacted.
+
+    The relation_type describes which type of relation is under test (i.e. it is
+    related to the value of rel_type in the event content).
+    """
 
     def _redact(self, event_id: str) -> None:
         channel = self.make_request(
@@ -1358,40 +1429,116 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self.assertEqual(200, channel.code, channel.json_body)
 
     def test_redact_relation_annotation(self) -> None:
-        """Test that annotations of an event are properly handled after the
+        """
+        Test that annotations of an event are properly handled after the
         annotation is redacted.
+
+        The redacted relation should not be included in bundled aggregations or
+        the response to relations.
         """
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self.assertEqual(200, channel.code, channel.json_body)
         to_redact_event_id = channel.json_body["event_id"]
 
         channel = self._send_relation(
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
-        self.assertEqual(200, channel.code, channel.json_body)
+        unredacted_event_id = channel.json_body["event_id"]
+
+        # Both relations should exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
+        self.assertEquals(
+            relations["m.annotation"],
+            {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
+        )
+
+        # Both relations appear in the aggregation.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
 
         # Redact one of the reactions.
         self._redact(to_redact_event_id)
 
-        # Ensure that the aggregations are correct.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
-            access_token=self.user_token,
+        # The unredacted relation should still exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(event_ids, [unredacted_event_id])
+        self.assertEquals(
+            relations["m.annotation"],
+            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
+        # The unredacted aggregation should still exist.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
+
+    def test_redact_relation_thread(self) -> None:
+        """
+        Test that thread replies are properly handled after the thread reply redacted.
+
+        The redacted event should not be included in bundled aggregations or
+        the response to relations.
+        """
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            content={"body": "reply 1", "msgtype": "m.text"},
+        )
+        unredacted_event_id = channel.json_body["event_id"]
+
+        # Note that the *last* event in the thread is redacted, as that gets
+        # included in the bundled aggregation.
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            content={"body": "reply 2", "msgtype": "m.text"},
+        )
+        to_redact_event_id = channel.json_body["event_id"]
+
+        # Both relations exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
+        self.assertDictContainsSubset(
+            {
+                "count": 2,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        # And the latest event returned is the event that will be redacted.
+        self.assertEqual(
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            to_redact_event_id,
+        )
+
+        # Redact one of the reactions.
+        self._redact(to_redact_event_id)
+
+        # The unredacted relation should still exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(event_ids, [unredacted_event_id])
+        self.assertDictContainsSubset(
+            {
+                "count": 1,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        # And the latest event is now the unredacted event.
         self.assertEqual(
-            channel.json_body,
-            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            unredacted_event_id,
         )
 
-    def test_redact_relation_edit(self) -> None:
+    def test_redact_parent_edit(self) -> None:
         """Test that edits of an event are redacted when the original event
         is redacted.
         """
         # Add a relation
-        channel = self._send_relation(
+        self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
             parent_id=self.parent_id,
@@ -1401,54 +1548,83 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
                 "m.new_content": {"msgtype": "m.text", "body": "First edit"},
             },
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
         # Check the relation is returned
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}/m.replace/m.room.message",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(len(channel.json_body["chunk"]), 1)
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEqual(len(event_ids), 1)
+        self.assertIn(RelationTypes.REPLACE, relations)
 
         # Redact the original event
         self._redact(self.parent_id)
 
-        # Try to check for remaining m.replace relations
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}/m.replace/m.room.message",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
+        # The relations are not returned.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEqual(len(event_ids), 0)
+        self.assertEqual(relations, {})
 
-        # Check that no relations are returned
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
-
-    def test_redact_parent(self) -> None:
-        """Test that annotations of an event are redacted when the original event
+    def test_redact_parent_annotation(self) -> None:
+        """Test that annotations of an event are viewable when the original event
         is redacted.
         """
         # Add a relation
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
-        self.assertEqual(200, channel.code, channel.json_body)
+        related_event_id = channel.json_body["event_id"]
+
+        # The relations should exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEqual(len(event_ids), 1)
+        self.assertIn(RelationTypes.ANNOTATION, relations)
+
+        # The aggregation should exist.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
 
         # Redact the original event.
         self._redact(self.parent_id)
 
-        # Check that aggregations returns zero
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
-            access_token=self.user_token,
+        # The relations are returned.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(event_ids, [related_event_id])
+        self.assertEquals(
+            relations["m.annotation"],
+            {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
         )
-        self.assertEqual(200, channel.code, channel.json_body)
 
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
+        # There's nothing to aggregate.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
+
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_redact_parent_thread(self) -> None:
+        """
+        Test that thread replies are still available when the root event is redacted.
+        """
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            content={"body": "reply 1", "msgtype": "m.text"},
+        )
+        related_event_id = channel.json_body["event_id"]
+
+        # Redact one of the reactions.
+        self._redact(self.parent_id)
+
+        # The unredacted relation should still exist.
+        event_ids = self._get_related_events()
+        relations = self._get_bundled_aggregations()
+        self.assertEquals(len(event_ids), 1)
+        self.assertDictContainsSubset(
+            {
+                "count": 1,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        self.assertEqual(
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            related_event_id,
+        )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index f3bf8d0934..7b8fe6d025 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -24,6 +24,7 @@ from synapse.util import Clock
 from synapse.visibility import filter_events_for_client
 
 from tests import unittest
+from tests.unittest import override_config
 
 one_hour_ms = 3600000
 one_day_ms = one_hour_ms * 24
@@ -38,7 +39,10 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
-        config["retention"] = {
+
+        # merge this default retention config with anything that was specified in
+        # @override_config
+        retention_config = {
             "enabled": True,
             "default_policy": {
                 "min_lifetime": one_day_ms,
@@ -47,6 +51,8 @@ class RetentionTestCase(unittest.HomeserverTestCase):
             "allowed_lifetime_min": one_day_ms,
             "allowed_lifetime_max": one_day_ms * 3,
         }
+        retention_config.update(config.get("retention", {}))
+        config["retention"] = retention_config
 
         self.hs = self.setup_test_homeserver(config=config)
 
@@ -115,22 +121,20 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         self._test_retention_event_purged(room_id, one_day_ms * 2)
 
+    @override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}})
     def test_visibility(self) -> None:
         """Tests that synapse.visibility.filter_events_for_client correctly filters out
-        outdated events
+        outdated events, even if the purge job hasn't got to them yet.
+
+        We do this by setting a very long time between purge jobs.
         """
         store = self.hs.get_datastores().main
         storage = self.hs.get_storage()
         room_id = self.helper.create_room_as(self.user_id, tok=self.token)
-        events = []
 
         # Send a first event, which should be filtered out at the end of the test.
         resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
-
-        # Get the event from the store so that we end up with a FrozenEvent that we can
-        # give to filter_events_for_client. We need to do this now because the event won't
-        # be in the database anymore after it has expired.
-        events.append(self.get_success(store.get_event(resp.get("event_id"))))
+        first_event_id = resp.get("event_id")
 
         # Advance the time by 2 days. We're using the default retention policy, therefore
         # after this the first event will still be valid.
@@ -138,16 +142,17 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         # Send another event, which shouldn't get filtered out.
         resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
-
         valid_event_id = resp.get("event_id")
 
-        events.append(self.get_success(store.get_event(valid_event_id)))
-
         # Advance the time by another 2 days. After this, the first event should be
         # outdated but not the second one.
         self.reactor.advance(one_day_ms * 2 / 1000)
 
-        # Run filter_events_for_client with our list of FrozenEvents.
+        # Fetch the events, and run filter_events_for_client on them
+        events = self.get_success(
+            store.get_events_as_list([first_event_id, valid_event_id])
+        )
+        self.assertEqual(2, len(events), "events retrieved from database")
         filtered_events = self.get_success(
             filter_events_for_client(storage, self.user_id, events)
         )
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 37866ee330..3a9617d6da 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
     def test_filter_relation_senders(self) -> None:
         # Messages which second user reacted to.
-        filter = {"io.element.relation_senders": [self.second_user_id]}
+        filter = {"related_by_senders": [self.second_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0]["event_id"], self.event_id_1)
 
         # Messages which third user reacted to.
-        filter = {"io.element.relation_senders": [self.third_user_id]}
+        filter = {"related_by_senders": [self.third_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0]["event_id"], self.event_id_2)
 
         # Messages which either user reacted to.
-        filter = {
-            "io.element.relation_senders": [self.second_user_id, self.third_user_id]
-        }
+        filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 2, chunk)
         self.assertCountEqual(
@@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
     def test_filter_relation_type(self) -> None:
         # Messages which have annotations.
-        filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+        filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0]["event_id"], self.event_id_1)
 
         # Messages which have references.
-        filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+        filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0]["event_id"], self.event_id_2)
 
         # Messages which have either annotations or references.
         filter = {
-            "io.element.relation_types": [
+            "related_by_rel_types": [
                 RelationTypes.ANNOTATION,
                 RelationTypes.REFERENCE,
             ]
@@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
     def test_filter_relation_senders_and_type(self) -> None:
         # Messages which second user reacted to.
         filter = {
-            "io.element.relation_senders": [self.second_user_id],
-            "io.element.relation_types": [RelationTypes.ANNOTATION],
+            "related_by_senders": [self.second_user_id],
+            "related_by_rel_types": [RelationTypes.ANNOTATION],
         }
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 58f1ea11b7..e7de67e3a3 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -775,3 +775,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         self.assertEqual(args[0], user_id)
         self.assertFalse(args[1])
         self.assertTrue(args[2])
+
+    def test_check_can_deactivate_user(self) -> None:
+        """Tests that the on_user_deactivation_status_changed module callback is called
+        correctly when processing a user's deactivation.
+        """
+        # Register a mocked callback.
+        deactivation_mock = Mock(return_value=make_awaitable(False))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._check_can_deactivate_user_callbacks.append(
+            deactivation_mock,
+        )
+
+        # Register a user that we'll deactivate.
+        user_id = self.register_user("altan", "password")
+        tok = self.login("altan", "password")
+
+        # Deactivate that user.
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/account/deactivate",
+            {
+                "auth": {
+                    "type": LoginType.PASSWORD,
+                    "password": "password",
+                    "identifier": {
+                        "type": "m.id.user",
+                        "user": user_id,
+                    },
+                },
+                "erase": True,
+            },
+            access_token=tok,
+        )
+
+        # Check that the deactivation was blocked
+        self.assertEqual(channel.code, 403, channel.json_body)
+
+        # Check that the mock was called once.
+        deactivation_mock.assert_called_once()
+        args = deactivation_mock.call_args[0]
+
+        # Check that the mock was called with the right user ID
+        self.assertEqual(args[0], user_id)
+
+        # Check that the request was not made by an admin
+        self.assertEqual(args[1], False)
+
+    def test_check_can_deactivate_user_admin(self) -> None:
+        """Tests that the on_user_deactivation_status_changed module callback is called
+        correctly when processing a user's deactivation triggered by a server admin.
+        """
+        # Register a mocked callback.
+        deactivation_mock = Mock(return_value=make_awaitable(False))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._check_can_deactivate_user_callbacks.append(
+            deactivation_mock,
+        )
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Register a user that we'll deactivate.
+        user_id = self.register_user("altan", "password")
+
+        # Deactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {"deactivated": True},
+            access_token=admin_tok,
+        )
+
+        # Check that the deactivation was blocked
+        self.assertEqual(channel.code, 403, channel.json_body)
+
+        # Check that the mock was called once.
+        deactivation_mock.assert_called_once()
+        args = deactivation_mock.call_args[0]
+
+        # Check that the mock was called with the right user ID
+        self.assertEqual(args[0], user_id)
+
+        # Check that the mock was made by an admin
+        self.assertEqual(args[1], True)
+
+    def test_check_can_shutdown_room(self) -> None:
+        """Tests that the check_can_shutdown_room module callback is called
+        correctly when processing an admin's shutdown room request.
+        """
+        # Register a mocked callback.
+        shutdown_mock = Mock(return_value=make_awaitable(False))
+        third_party_rules = self.hs.get_third_party_event_rules()
+        third_party_rules._check_can_shutdown_room_callbacks.append(
+            shutdown_mock,
+        )
+
+        # Register an admin user.
+        admin_user_id = self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Shutdown the room.
+        channel = self.make_request(
+            "DELETE",
+            "/_synapse/admin/v2/rooms/%s" % self.room_id,
+            {},
+            access_token=admin_tok,
+        )
+
+        # Check that the shutdown was blocked
+        self.assertEqual(channel.code, 403, channel.json_body)
+
+        # Check that the mock was called once.
+        shutdown_mock.assert_called_once()
+        args = shutdown_mock.call_args[0]
+
+        # Check that the mock was called with the right user ID
+        self.assertEqual(args[0], admin_user_id)
+
+        # Check that the mock was called with the right room ID
+        self.assertEqual(args[1], self.room_id)
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3b5747cb12..8d8251b2ac 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,3 +1,18 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
 from unittest.mock import Mock, call
 
 from twisted.internet import defer, reactor
@@ -11,14 +26,14 @@ from tests.utils import MockClock
 
 
 class HttpTransactionCacheTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.clock = MockClock()
         self.hs = Mock()
         self.hs.get_clock = Mock(return_value=self.clock)
         self.hs.get_auth = Mock()
         self.cache = HttpTransactionCache(self.hs)
 
-        self.mock_http_response = (200, "GOOD JOB!")
+        self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
         self.mock_key = "foo"
 
     @defer.inlineCallbacks
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 4672a68596..978c252f84 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -13,19 +13,24 @@
 # limitations under the License.
 import urllib.parse
 from io import BytesIO, StringIO
+from typing import Any, Dict, Optional, Union
 from unittest.mock import Mock
 
 import signedjson.key
 from canonicaljson import encode_canonical_json
-from nacl.signing import SigningKey
 from signedjson.sign import sign_json
+from signedjson.types import SigningKey
 
-from twisted.web.resource import NoResource
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import NoResource, Resource
 
 from synapse.crypto.keyring import PerspectivesKeyFetcher
 from synapse.http.site import SynapseRequest
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.server import HomeServer
 from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
+from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.stringutils import random_string
 
@@ -35,11 +40,11 @@ from tests.utils import default_config
 
 
 class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.http_client = Mock()
         return self.setup_test_homeserver(federation_http_client=self.http_client)
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> Resource:
         return create_resource_tree(
             {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
         )
@@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect an outgoing GET request for the given key
         """
 
-        async def get_json(destination, path, ignore_backoff=False, **kwargs):
+        async def get_json(
+            destination: str,
+            path: str,
+            ignore_backoff: bool = False,
+            **kwargs: Any,
+        ) -> Union[JsonDict, list]:
             self.assertTrue(ignore_backoff)
             self.assertEqual(destination, server_name)
             key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
         Checks that the response is a 200 and returns the decoded json body.
         """
         channel = FakeChannel(self.site, self.reactor)
-        req = SynapseRequest(channel, self.site)
+        # channel is a `FakeChannel` but `HTTPChannel` is expected
+        req = SynapseRequest(channel, self.site)  # type: ignore[arg-type]
         req.content = BytesIO(b"")
         req.requestReceived(
             b"GET",
@@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
         resp = channel.json_body
         return resp
 
-    def test_get_key(self):
+    def test_get_key(self) -> None:
         """Fetch a remote key"""
         SERVER_NAME = "remote.server"
         testkey = signedjson.key.generate_signing_key("ver1")
@@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
         self.assertIn(SERVER_NAME, keys[0]["signatures"])
         self.assertIn(self.hs.hostname, keys[0]["signatures"])
 
-    def test_get_own_key(self):
+    def test_get_own_key(self) -> None:
         """Fetch our own key"""
         testkey = signedjson.key.generate_signing_key("ver1")
         self.expect_outgoing_key_request(self.hs.hostname, testkey)
@@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
     endpoint, to check that the two implementations are compatible.
     """
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
 
         # replace the signing key with our own
@@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # make a second homeserver, configured to use the first one as a key notary
         self.http_client2 = Mock()
         config = default_config(name="keyclient")
@@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         # wire up outbound POST /key/v2/query requests from hs2 so that they
         # will be forwarded to hs1
-        async def post_json(destination, path, data):
+        async def post_json(
+            destination: str, path: str, data: Optional[JsonDict] = None
+        ) -> Union[JsonDict, list]:
             self.assertEqual(destination, self.hs.hostname)
             self.assertEqual(
                 path,
@@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             )
 
             channel = FakeChannel(self.site, self.reactor)
-            req = SynapseRequest(channel, self.site)
+            # channel is a `FakeChannel` but `HTTPChannel` is expected
+            req = SynapseRequest(channel, self.site)  # type: ignore[arg-type]
             req.content = BytesIO(encode_canonical_json(data))
 
             req.requestReceived(
@@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         self.http_client2.post_json.side_effect = post_json
 
-    def test_get_key(self):
+    def test_get_key(self) -> None:
         """Fetch a key belonging to a random server"""
         # make up a key to be fetched.
         testkey = signedjson.key.generate_signing_key("abc")
@@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             signedjson.key.encode_verify_key_base64(testkey.verify_key),
         )
 
-    def test_get_notary_key(self):
+    def test_get_notary_key(self) -> None:
         """Fetch a key belonging to the notary server"""
         # make up a key to be fetched. We randomise the keyid to try to get it to
         # appear before the key server signing key sometimes (otherwise we bail out
@@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             signedjson.key.encode_verify_key_base64(testkey.verify_key),
         )
 
-    def test_get_notary_keyserver_key(self):
+    def test_get_notary_keyserver_key(self) -> None:
         """Fetch the notary's keyserver key"""
         # we expect hs1 to make a regular key request to itself
         self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)
diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py
index f761e23f1b..c73179151a 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/rest/media/v1/test_base.py
@@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
         b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
     }
 
-    def tests(self):
+    def tests(self) -> None:
         for hdr, expected in self.TEST_CASES.items():
             res = get_filename_from_headers({b"Content-Disposition": [hdr]})
             self.assertEqual(
                 res,
                 expected,
-                "expected output for %s to be %s but was %s" % (hdr, expected, res),
+                f"expected output for {hdr!r} to be {expected} but was {res}",
             )
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
index 913bc530aa..43e6f0f70a 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/rest/media/v1/test_filepath.py
@@ -21,12 +21,12 @@ from tests import unittest
 
 
 class MediaFilePathsTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
 
         self.filepaths = MediaFilePaths("/media_store")
 
-    def test_local_media_filepath(self):
+    def test_local_media_filepath(self) -> None:
         """Test local media paths"""
         self.assertEqual(
             self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_local_media_thumbnail(self):
+    def test_local_media_thumbnail(self) -> None:
         """Test local media thumbnail paths"""
         self.assertEqual(
             self.filepaths.local_media_thumbnail_rel(
@@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
         )
 
-    def test_local_media_thumbnail_dir(self):
+    def test_local_media_thumbnail_dir(self) -> None:
         """Test local media thumbnail directory paths"""
         self.assertEqual(
             self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
             "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_remote_media_filepath(self):
+    def test_remote_media_filepath(self) -> None:
         """Test remote media paths"""
         self.assertEqual(
             self.filepaths.remote_media_filepath_rel(
@@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_remote_media_thumbnail(self):
+    def test_remote_media_thumbnail(self) -> None:
         """Test remote media thumbnail paths"""
         self.assertEqual(
             self.filepaths.remote_media_thumbnail_rel(
@@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
         )
 
-    def test_remote_media_thumbnail_legacy(self):
+    def test_remote_media_thumbnail_legacy(self) -> None:
         """Test old-style remote media thumbnail paths"""
         self.assertEqual(
             self.filepaths.remote_media_thumbnail_rel_legacy(
@@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
         )
 
-    def test_remote_media_thumbnail_dir(self):
+    def test_remote_media_thumbnail_dir(self) -> None:
         """Test remote media thumbnail directory paths"""
         self.assertEqual(
             self.filepaths.remote_media_thumbnail_dir(
@@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_url_cache_filepath(self):
+    def test_url_cache_filepath(self) -> None:
         """Test URL cache paths"""
         self.assertEqual(
             self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
@@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
         )
 
-    def test_url_cache_filepath_legacy(self):
+    def test_url_cache_filepath_legacy(self) -> None:
         """Test old-style URL cache paths"""
         self.assertEqual(
             self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_url_cache_filepath_dirs_to_delete(self):
+    def test_url_cache_filepath_dirs_to_delete(self) -> None:
         """Test URL cache cleanup paths"""
         self.assertEqual(
             self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ["/media_store/url_cache/2020-01-02"],
         )
 
-    def test_url_cache_filepath_dirs_to_delete_legacy(self):
+    def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None:
         """Test old-style URL cache cleanup paths"""
         self.assertEqual(
             self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_url_cache_thumbnail(self):
+    def test_url_cache_thumbnail(self) -> None:
         """Test URL cache thumbnail paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_rel(
@@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
         )
 
-    def test_url_cache_thumbnail_legacy(self):
+    def test_url_cache_thumbnail_legacy(self) -> None:
         """Test old-style URL cache thumbnail paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_rel(
@@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
         )
 
-    def test_url_cache_thumbnail_directory(self):
+    def test_url_cache_thumbnail_directory(self) -> None:
         """Test URL cache thumbnail directory paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_directory_rel(
@@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
         )
 
-    def test_url_cache_thumbnail_directory_legacy(self):
+    def test_url_cache_thumbnail_directory_legacy(self) -> None:
         """Test old-style URL cache thumbnail directory paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_directory_rel(
@@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
         )
 
-    def test_url_cache_thumbnail_dirs_to_delete(self):
+    def test_url_cache_thumbnail_dirs_to_delete(self) -> None:
         """Test URL cache thumbnail cleanup paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
+    def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None:
         """Test old-style URL cache thumbnail cleanup paths"""
         self.assertEqual(
             self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_server_name_validation(self):
+    def test_server_name_validation(self) -> None:
         """Test validation of server names"""
         self._test_path_validation(
             [
@@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_file_id_validation(self):
+    def test_file_id_validation(self) -> None:
         """Test validation of local, remote and legacy URL cache file / media IDs"""
         # File / media IDs get split into three parts to form paths, consisting of the
         # first two characters, next two characters and rest of the ID.
@@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             invalid_values=invalid_file_ids,
         )
 
-    def test_url_cache_media_id_validation(self):
+    def test_url_cache_media_id_validation(self) -> None:
         """Test validation of URL cache media IDs"""
         self._test_path_validation(
             [
@@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_content_type_validation(self):
+    def test_content_type_validation(self) -> None:
         """Test validation of thumbnail content types"""
         self._test_path_validation(
             [
@@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
             ],
         )
 
-    def test_thumbnail_method_validation(self):
+    def test_thumbnail_method_validation(self) -> None:
         """Test validation of thumbnail methods"""
         self._test_path_validation(
             [
@@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
         parameter: str,
         valid_values: Iterable[str],
         invalid_values: Iterable[str],
-    ):
+    ) -> None:
         """Test that the specified methods validate the named parameter as expected
 
         Args:
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py
index a4b57e3d1f..62e308814d 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/rest/media/v1/test_html_preview.py
@@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import (
     _get_html_media_encodings,
     decode_body,
     parse_html_to_open_graph,
-    rebase_url,
     summarize_paragraphs,
 )
 
@@ -32,7 +31,7 @@ class SummarizeTestCase(unittest.TestCase):
     if not lxml:
         skip = "url preview feature requires lxml"
 
-    def test_long_summarize(self):
+    def test_long_summarize(self) -> None:
         example_paras = [
             """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
             Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
@@ -90,7 +89,7 @@ class SummarizeTestCase(unittest.TestCase):
             " Tromsøya had a population of 36,088. Substantial parts of the urban…",
         )
 
-    def test_short_summarize(self):
+    def test_short_summarize(self) -> None:
         example_paras = [
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -117,7 +116,7 @@ class SummarizeTestCase(unittest.TestCase):
             " most of the year.",
         )
 
-    def test_small_then_large_summarize(self):
+    def test_small_then_large_summarize(self) -> None:
         example_paras = [
             "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
             " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -150,7 +149,7 @@ class CalcOgTestCase(unittest.TestCase):
     if not lxml:
         skip = "url preview feature requires lxml"
 
-    def test_simple(self):
+    def test_simple(self) -> None:
         html = b"""
         <html>
         <head><title>Foo</title></head>
@@ -161,11 +160,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_comment(self):
+    def test_comment(self) -> None:
         html = b"""
         <html>
         <head><title>Foo</title></head>
@@ -177,11 +176,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_comment2(self):
+    def test_comment2(self) -> None:
         html = b"""
         <html>
         <head><title>Foo</title></head>
@@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(
             og,
@@ -206,7 +205,7 @@ class CalcOgTestCase(unittest.TestCase):
             },
         )
 
-    def test_script(self):
+    def test_script(self) -> None:
         html = b"""
         <html>
         <head><title>Foo</title></head>
@@ -218,11 +217,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_missing_title(self):
+    def test_missing_title(self) -> None:
         html = b"""
         <html>
         <body>
@@ -232,11 +231,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
-    def test_h1_as_title(self):
+    def test_h1_as_title(self) -> None:
         html = b"""
         <html>
         <meta property="og:description" content="Some text."/>
@@ -247,11 +246,11 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
-    def test_missing_title_and_broken_h1(self):
+    def test_missing_title_and_broken_h1(self) -> None:
         html = b"""
         <html>
         <body>
@@ -262,23 +261,23 @@ class CalcOgTestCase(unittest.TestCase):
         """
 
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
-    def test_empty(self):
+    def test_empty(self) -> None:
         """Test a body with no data in it."""
         html = b""
         tree = decode_body(html, "http://example.com/test.html")
         self.assertIsNone(tree)
 
-    def test_no_tree(self):
+    def test_no_tree(self) -> None:
         """A valid body with no tree in it."""
         html = b"\x00"
         tree = decode_body(html, "http://example.com/test.html")
         self.assertIsNone(tree)
 
-    def test_xml(self):
+    def test_xml(self) -> None:
         """Test decoding XML and ensure it works properly."""
         # Note that the strip() call is important to ensure the xml tag starts
         # at the initial byte.
@@ -290,10 +289,10 @@ class CalcOgTestCase(unittest.TestCase):
         <head><title>Foo</title></head><body>Some text.</body></html>
         """.strip()
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_invalid_encoding(self):
+    def test_invalid_encoding(self) -> None:
         """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
         html = b"""
         <html>
@@ -304,10 +303,10 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
-    def test_invalid_encoding2(self):
+    def test_invalid_encoding2(self) -> None:
         """A body which doesn't match the sent character encoding."""
         # Note that this contains an invalid UTF-8 sequence in the title.
         html = b"""
@@ -319,10 +318,10 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
 
-    def test_windows_1252(self):
+    def test_windows_1252(self) -> None:
         """A body which uses cp1252, but doesn't declare that."""
         html = b"""
         <html>
@@ -333,12 +332,12 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
         tree = decode_body(html, "http://example.com/test.html")
-        og = parse_html_to_open_graph(tree, "http://example.com/test.html")
+        og = parse_html_to_open_graph(tree)
         self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
 
 
 class MediaEncodingTestCase(unittest.TestCase):
-    def test_meta_charset(self):
+    def test_meta_charset(self) -> None:
         """A character encoding is found via the meta tag."""
         encodings = _get_html_media_encodings(
             b"""
@@ -363,7 +362,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
-    def test_meta_charset_underscores(self):
+    def test_meta_charset_underscores(self) -> None:
         """A character encoding contains underscore."""
         encodings = _get_html_media_encodings(
             b"""
@@ -376,7 +375,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
 
-    def test_xml_encoding(self):
+    def test_xml_encoding(self) -> None:
         """A character encoding is found via the meta tag."""
         encodings = _get_html_media_encodings(
             b"""
@@ -388,7 +387,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
-    def test_meta_xml_encoding(self):
+    def test_meta_xml_encoding(self) -> None:
         """Meta tags take precedence over XML encoding."""
         encodings = _get_html_media_encodings(
             b"""
@@ -402,7 +401,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
 
-    def test_content_type(self):
+    def test_content_type(self) -> None:
         """A character encoding is found via the Content-Type header."""
         # Test a few variations of the header.
         headers = (
@@ -417,12 +416,12 @@ class MediaEncodingTestCase(unittest.TestCase):
             encodings = _get_html_media_encodings(b"", header)
             self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
-    def test_fallback(self):
+    def test_fallback(self) -> None:
         """A character encoding cannot be found in the body or header."""
         encodings = _get_html_media_encodings(b"", "text/html")
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
 
-    def test_duplicates(self):
+    def test_duplicates(self) -> None:
         """Ensure each encoding is only attempted once."""
         encodings = _get_html_media_encodings(
             b"""
@@ -436,7 +435,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         )
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
 
-    def test_unknown_invalid(self):
+    def test_unknown_invalid(self) -> None:
         """A character encoding should be ignored if it is unknown or invalid."""
         encodings = _get_html_media_encodings(
             b"""
@@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase):
             'text/html; charset="invalid"',
         )
         self.assertEqual(list(encodings), ["utf-8", "cp1252"])
-
-
-class RebaseUrlTestCase(unittest.TestCase):
-    def test_relative(self):
-        """Relative URLs should be resolved based on the context of the base URL."""
-        self.assertEqual(
-            rebase_url("subpage", "https://example.com/foo/"),
-            "https://example.com/foo/subpage",
-        )
-        self.assertEqual(
-            rebase_url("sibling", "https://example.com/foo"),
-            "https://example.com/sibling",
-        )
-        self.assertEqual(
-            rebase_url("/bar", "https://example.com/foo/"),
-            "https://example.com/bar",
-        )
-
-    def test_absolute(self):
-        """Absolute URLs should not be modified."""
-        self.assertEqual(
-            rebase_url("https://alice.com/a/", "https://example.com/foo/"),
-            "https://alice.com/a/",
-        )
-
-    def test_data(self):
-        """Data URLs should not be modified."""
-        self.assertEqual(
-            rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
-            "data:,Hello%2C%20World%21",
-        )
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index cba9be17c4..7204b2dfe0 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
 import tempfile
 from binascii import unhexlify
 from io import BytesIO
-from typing import Optional
+from typing import Any, BinaryIO, Dict, List, Optional, Union
 from unittest.mock import Mock
 from urllib import parse
 
@@ -26,18 +26,24 @@ from PIL import Image as Image
 
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.events import EventBase
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.logging.context import make_deferred_yieldable
+from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
-from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
 from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+from synapse.server import HomeServer
+from synapse.types import RoomAlias
+from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
 from tests.test_utils import SMALL_PNG
 from tests.utils import default_config
 
@@ -46,7 +52,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
 
     needs_threadpool = True
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
         self.addCleanup(shutil.rmtree, self.test_dir)
 
@@ -62,7 +68,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
             hs, self.primary_base_path, self.filepaths, storage_providers
         )
 
-    def test_ensure_media_is_in_local_cache(self):
+    def test_ensure_media_is_in_local_cache(self) -> None:
         media_id = "some_media_id"
         test_body = "Test\n"
 
@@ -105,7 +111,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
         self.assertEqual(test_body, body)
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
 class _TestImage:
     """An image for testing thumbnailing with the expected results
 
@@ -121,18 +127,18 @@ class _TestImage:
             a 404 is expected.
     """
 
-    data = attr.ib(type=bytes)
-    content_type = attr.ib(type=bytes)
-    extension = attr.ib(type=bytes)
-    expected_cropped = attr.ib(type=Optional[bytes], default=None)
-    expected_scaled = attr.ib(type=Optional[bytes], default=None)
-    expected_found = attr.ib(default=True, type=bool)
+    data: bytes
+    content_type: bytes
+    extension: bytes
+    expected_cropped: Optional[bytes] = None
+    expected_scaled: Optional[bytes] = None
+    expected_found: bool = True
 
 
 @parameterized_class(
     ("test_image",),
     [
-        # smoll png
+        # small png
         (
             _TestImage(
                 SMALL_PNG,
@@ -193,11 +199,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
     hijack_auth = True
     user_id = "@test:user"
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         self.fetches = []
 
-        def get_file(destination, path, output_stream, args=None, max_size=None):
+        def get_file(
+            destination: str,
+            path: str,
+            output_stream: BinaryIO,
+            args: Optional[Dict[str, Union[str, List[str]]]] = None,
+            max_size: Optional[int] = None,
+        ) -> Deferred:
             """
             Returns tuple[int,dict,str,int] of file length, response headers,
             absolute URI, and response code.
@@ -238,7 +250,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 
         media_resource = hs.get_media_repository_resource()
         self.download_resource = media_resource.children[b"download"]
@@ -248,8 +260,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         self.media_id = "example.com/12345"
 
-    def _req(self, content_disposition, include_content_type=True):
-
+    def _req(
+        self, content_disposition: Optional[bytes], include_content_type: bool = True
+    ) -> FakeChannel:
         channel = make_request(
             self.reactor,
             FakeSite(self.download_resource, self.reactor),
@@ -288,7 +301,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
         return channel
 
-    def test_handle_missing_content_type(self):
+    def test_handle_missing_content_type(self) -> None:
         channel = self._req(
             b"inline; filename=out" + self.test_image.extension,
             include_content_type=False,
@@ -299,7 +312,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
         )
 
-    def test_disposition_filename_ascii(self):
+    def test_disposition_filename_ascii(self) -> None:
         """
         If the filename is filename=<ascii> then Synapse will decode it as an
         ASCII string, and use filename= in the response.
@@ -315,7 +328,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             [b"inline; filename=out" + self.test_image.extension],
         )
 
-    def test_disposition_filenamestar_utf8escaped(self):
+    def test_disposition_filenamestar_utf8escaped(self) -> None:
         """
         If the filename is filename=*utf8''<utf8 escaped> then Synapse will
         correctly decode it as the UTF-8 string, and use filename* in the
@@ -335,7 +348,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
         )
 
-    def test_disposition_none(self):
+    def test_disposition_none(self) -> None:
         """
         If there is no filename, one isn't passed on in the Content-Disposition
         of the request.
@@ -348,26 +361,26 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         )
         self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
 
-    def test_thumbnail_crop(self):
+    def test_thumbnail_crop(self) -> None:
         """Test that a cropped remote thumbnail is available."""
         self._test_thumbnail(
             "crop", self.test_image.expected_cropped, self.test_image.expected_found
         )
 
-    def test_thumbnail_scale(self):
+    def test_thumbnail_scale(self) -> None:
         """Test that a scaled remote thumbnail is available."""
         self._test_thumbnail(
             "scale", self.test_image.expected_scaled, self.test_image.expected_found
         )
 
-    def test_invalid_type(self):
+    def test_invalid_type(self) -> None:
         """An invalid thumbnail type is never available."""
         self._test_thumbnail("invalid", None, False)
 
     @unittest.override_config(
         {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
     )
-    def test_no_thumbnail_crop(self):
+    def test_no_thumbnail_crop(self) -> None:
         """
         Override the config to generate only scaled thumbnails, but request a cropped one.
         """
@@ -376,13 +389,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
     @unittest.override_config(
         {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
     )
-    def test_no_thumbnail_scale(self):
+    def test_no_thumbnail_scale(self) -> None:
         """
         Override the config to generate only cropped thumbnails, but request a scaled one.
         """
         self._test_thumbnail("scale", None, False)
 
-    def test_thumbnail_repeated_thumbnail(self):
+    def test_thumbnail_repeated_thumbnail(self) -> None:
         """Test that fetching the same thumbnail works, and deleting the on disk
         thumbnail regenerates it.
         """
@@ -443,7 +456,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                 channel.result["body"],
             )
 
-    def _test_thumbnail(self, method, expected_body, expected_found):
+    def _test_thumbnail(
+        self, method: str, expected_body: Optional[bytes], expected_found: bool
+    ) -> None:
         params = "?width=32&height=32&method=" + method
         channel = make_request(
             self.reactor,
@@ -485,7 +500,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             )
 
     @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
-    def test_same_quality(self, method, desired_size):
+    def test_same_quality(self, method: str, desired_size: int) -> None:
         """Test that choosing between thumbnails with the same quality rating succeeds.
 
         We are not particular about which thumbnail is chosen."""
@@ -521,7 +536,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             )
         )
 
-    def test_x_robots_tag_header(self):
+    def test_x_robots_tag_header(self) -> None:
         """
         Tests that the `X-Robots-Tag` header is present, which informs web crawlers
         to not index, archive, or follow links in media.
@@ -540,29 +555,38 @@ class TestSpamChecker:
     `evil`.
     """
 
-    def __init__(self, config, api):
+    def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
         self.config = config
         self.api = api
 
-    def parse_config(config):
+    def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
         return config
 
-    async def check_event_for_spam(self, foo):
+    async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
         return False  # allow all events
 
-    async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+    async def user_may_invite(
+        self,
+        inviter_userid: str,
+        invitee_userid: str,
+        room_id: str,
+    ) -> bool:
         return True  # allow all invites
 
-    async def user_may_create_room(self, userid):
+    async def user_may_create_room(self, userid: str) -> bool:
         return True  # allow all room creations
 
-    async def user_may_create_room_alias(self, userid, room_alias):
+    async def user_may_create_room_alias(
+        self, userid: str, room_alias: RoomAlias
+    ) -> bool:
         return True  # allow all room aliases
 
-    async def user_may_publish_room(self, userid, room_id):
+    async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
         return True  # allow publishing of all rooms
 
-    async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+    async def check_media_file_for_spam(
+        self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
+    ) -> bool:
         buf = BytesIO()
         await file_wrapper.write_chunks_to(buf.write)
 
@@ -575,7 +599,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
         admin.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user = self.register_user("user", "pass")
         self.tok = self.login("user", "pass")
 
@@ -586,7 +610,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
         load_legacy_spam_checkers(hs)
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = default_config("test")
 
         config.update(
@@ -602,13 +626,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
         return config
 
-    def test_upload_innocent(self):
+    def test_upload_innocent(self) -> None:
         """Attempt to upload some innocent data that should be allowed."""
         self.helper.upload_media(
             self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
         )
 
-    def test_upload_ban(self):
+    def test_upload_ban(self) -> None:
         """Attempt to upload some data that includes bytes "evil", which should
         get rejected by the spam checker.
         """
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
index 048d0ca44a..f38d7225f8 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/rest/media/v1/test_oembed.py
@@ -16,7 +16,7 @@ import json
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase
 
 
 class OEmbedTests(HomeserverTestCase):
-    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
-        self.oembed = OEmbedProvider(homeserver)
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.oembed = OEmbedProvider(hs)
 
-    def parse_response(self, response: JsonDict):
+    def parse_response(self, response: JsonDict) -> OEmbedResult:
         return self.oembed.parse_oembed_response(
             "https://test", json.dumps(response).encode("utf-8")
         )
 
-    def test_version(self):
+    def test_version(self) -> None:
         """Accept versions that are similar to 1.0 as a string or int (or missing)."""
         for version in ("1.0", 1.0, 1):
             result = self.parse_response({"version": version, "type": "link"})
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index da2c533260..5148c39874 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -16,16 +16,21 @@ import base64
 import json
 import os
 import re
+from typing import Any, Dict, Optional, Sequence, Tuple, Type
 from urllib.parse import urlencode
 
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.error import DNSLookupError
-from twisted.test.proto_helpers import AccumulatingProtocol
+from twisted.internet.interfaces import IAddress, IResolutionReceiver
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
 
 from synapse.config.oembed import OEmbedEndpointConfig
+from synapse.rest.media.v1.media_repository import MediaRepositoryResource
 from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.server import HomeServer
 from synapse.types import JsonDict
+from synapse.util import Clock
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
 from tests import unittest
@@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         b"</head></html>"
     )
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
         config["url_preview_enabled"] = True
@@ -113,22 +118,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
 
         self.media_repo = hs.get_media_repository_resource()
         self.preview_url = self.media_repo.children[b"preview_url"]
 
-        self.lookups = {}
+        self.lookups: Dict[str, Any] = {}
 
         class Resolver:
             def resolveHostName(
                 _self,
-                resolutionReceiver,
-                hostName,
-                portNumber=0,
-                addressTypes=None,
-                transportSemantics="TCP",
-            ):
+                resolutionReceiver: IResolutionReceiver,
+                hostName: str,
+                portNumber: int = 0,
+                addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+                transportSemantics: str = "TCP",
+            ) -> IResolutionReceiver:
 
                 resolution = HostResolution(hostName)
                 resolutionReceiver.resolutionBegan(resolution)
@@ -140,9 +145,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 resolutionReceiver.resolutionComplete()
                 return resolutionReceiver
 
-        self.reactor.nameResolver = Resolver()
+        self.reactor.nameResolver = Resolver()  # type: ignore[assignment]
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> MediaRepositoryResource:
         return self.hs.get_media_repository_resource()
 
     def _assert_small_png(self, json_body: JsonDict) -> None:
@@ -153,7 +158,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(json_body["og:image:type"], "image/png")
         self.assertEqual(json_body["matrix:image:size"], 67)
 
-    def test_cache_returns_correct_type(self):
+    def test_cache_returns_correct_type(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         channel = self.make_request(
@@ -207,7 +212,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
         )
 
-    def test_non_ascii_preview_httpequiv(self):
+    def test_non_ascii_preview_httpequiv(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
@@ -243,7 +248,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
-    def test_video_rejected(self):
+    def test_video_rejected(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = b"anything"
@@ -279,7 +284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_audio_rejected(self):
+    def test_audio_rejected(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = b"anything"
@@ -315,7 +320,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_non_ascii_preview_content_type(self):
+    def test_non_ascii_preview_content_type(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
@@ -350,7 +355,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
-    def test_overlong_title(self):
+    def test_overlong_title(self) -> None:
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
@@ -387,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         # We should only see the `og:description` field, as `title` is too long and should be stripped out
         self.assertCountEqual(["og:description"], res.keys())
 
-    def test_ipaddr(self):
+    def test_ipaddr(self) -> None:
         """
         IP addresses can be previewed directly.
         """
@@ -417,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
         )
 
-    def test_blacklisted_ip_specific(self):
+    def test_blacklisted_ip_specific(self) -> None:
         """
         Blacklisted IP addresses, found via DNS, are not spidered.
         """
@@ -438,7 +443,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ip_range(self):
+    def test_blacklisted_ip_range(self) -> None:
         """
         Blacklisted IP ranges, IPs found over DNS, are not spidered.
         """
@@ -457,7 +462,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ip_specific_direct(self):
+    def test_blacklisted_ip_specific_direct(self) -> None:
         """
         Blacklisted IP addresses, accessed directly, are not spidered.
         """
@@ -476,7 +481,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 403)
 
-    def test_blacklisted_ip_range_direct(self):
+    def test_blacklisted_ip_range_direct(self) -> None:
         """
         Blacklisted IP ranges, accessed directly, are not spidered.
         """
@@ -493,7 +498,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ip_range_whitelisted_ip(self):
+    def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
         """
         Blacklisted but then subsequently whitelisted IP addresses can be
         spidered.
@@ -526,7 +531,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
         )
 
-    def test_blacklisted_ip_with_external_ip(self):
+    def test_blacklisted_ip_with_external_ip(self) -> None:
         """
         If a hostname resolves a blacklisted IP, even if there's a
         non-blacklisted one, it will be rejected.
@@ -549,7 +554,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ipv6_specific(self):
+    def test_blacklisted_ipv6_specific(self) -> None:
         """
         Blacklisted IP addresses, found via DNS, are not spidered.
         """
@@ -572,7 +577,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_blacklisted_ipv6_range(self):
+    def test_blacklisted_ipv6_range(self) -> None:
         """
         Blacklisted IP ranges, IPs found over DNS, are not spidered.
         """
@@ -591,7 +596,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_OPTIONS(self):
+    def test_OPTIONS(self) -> None:
         """
         OPTIONS returns the OPTIONS.
         """
@@ -601,7 +606,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {})
 
-    def test_accept_language_config_option(self):
+    def test_accept_language_config_option(self) -> None:
         """
         Accept-Language header is sent to the remote server
         """
@@ -652,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             server.data,
         )
 
-    def test_data_url(self):
+    def test_data_url(self) -> None:
         """
         Requesting to preview a data URL is not supported.
         """
@@ -675,7 +680,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 500)
 
-    def test_inline_data_url(self):
+    def test_inline_data_url(self) -> None:
         """
         An inline image (as a data URL) should be parsed properly.
         """
@@ -712,7 +717,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self._assert_small_png(channel.json_body)
 
-    def test_oembed_photo(self):
+    def test_oembed_photo(self) -> None:
         """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
         self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
         self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -771,7 +776,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
         self._assert_small_png(body)
 
-    def test_oembed_rich(self):
+    def test_oembed_rich(self) -> None:
         """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
         self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
 
@@ -817,7 +822,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_oembed_format(self):
+    def test_oembed_format(self) -> None:
         """Test an oEmbed endpoint which requires the format in the URL."""
         self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
 
@@ -866,7 +871,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             },
         )
 
-    def test_oembed_autodiscovery(self):
+    def test_oembed_autodiscovery(self) -> None:
         """
         Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
         1. Request a preview of a URL which is not known to the oEmbed code.
@@ -962,7 +967,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         )
         self._assert_small_png(body)
 
-    def _download_image(self):
+    def _download_image(self) -> Tuple[str, str]:
         """Downloads an image into the URL cache.
         Returns:
             A (host, media_id) tuple representing the MXC URI of the image.
@@ -995,7 +1000,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertIsNone(_port)
         return host, media_id
 
-    def test_storage_providers_exclude_files(self):
+    def test_storage_providers_exclude_files(self) -> None:
         """Test that files are not stored in or fetched from storage providers."""
         host, media_id = self._download_image()
 
@@ -1037,7 +1042,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             "URL cache file was unexpectedly retrieved from a storage provider",
         )
 
-    def test_storage_providers_exclude_thumbnails(self):
+    def test_storage_providers_exclude_thumbnails(self) -> None:
         """Test that thumbnails are not stored in or fetched from storage providers."""
         host, media_id = self._download_image()
 
@@ -1090,7 +1095,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             "URL cache thumbnail was unexpectedly retrieved from a storage provider",
         )
 
-    def test_cache_expiry(self):
+    def test_cache_expiry(self) -> None:
         """Test that URL cache files and thumbnails are cleaned up properly on expiry."""
         self.preview_url.clock = MockClock()
 
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 01d48c3860..da325955f8 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from http import HTTPStatus
 
 from synapse.rest.health import HealthResource
 
@@ -19,12 +19,12 @@ from tests import unittest
 
 
 class HealthCheckTests(unittest.HomeserverTestCase):
-    def create_test_resource(self):
+    def create_test_resource(self) -> HealthResource:
         # replace the JsonResource with a HealthResource.
         return HealthResource()
 
-    def test_health(self):
+    def test_health(self) -> None:
         channel = self.make_request("GET", "/health", shorthand=False)
 
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
         self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index 118aa93a32..11f78f52b8 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
+
 from twisted.web.resource import Resource
 
 from synapse.rest.well_known import well_known_resource
@@ -19,7 +21,7 @@ from tests import unittest
 
 
 class WellKnownTests(unittest.HomeserverTestCase):
-    def create_test_resource(self):
+    def create_test_resource(self) -> Resource:
         # replace the JsonResource with a Resource wrapping the WellKnownResource
         res = Resource()
         res.putChild(b".well-known", well_known_resource(self.hs))
@@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase):
             "default_identity_server": "https://testis",
         }
     )
-    def test_client_well_known(self):
+    def test_client_well_known(self) -> None:
         channel = self.make_request(
             "GET", "/.well-known/matrix/client", shorthand=False
         )
 
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
         self.assertEqual(
             channel.json_body,
             {
@@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase):
             "public_baseurl": None,
         }
     )
-    def test_client_well_known_no_public_baseurl(self):
+    def test_client_well_known_no_public_baseurl(self) -> None:
         channel = self.make_request(
             "GET", "/.well-known/matrix/client", shorthand=False
         )
 
-        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
 
     @unittest.override_config({"serve_server_wellknown": True})
-    def test_server_well_known(self):
+    def test_server_well_known(self) -> None:
         channel = self.make_request(
             "GET", "/.well-known/matrix/server", shorthand=False
         )
 
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, HTTPStatus.OK)
         self.assertEqual(
             channel.json_body,
             {"m.server": "test:443"},
         )
 
-    def test_server_well_known_disabled(self):
+    def test_server_well_known_disabled(self) -> None:
         channel = self.make_request(
             "GET", "/.well-known/matrix/server", shorthand=False
         )
-        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
diff --git a/tests/server.py b/tests/server.py
index 82990c2eb9..6ce2a17bf4 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -54,13 +54,18 @@ from twisted.internet.interfaces import (
     ITransport,
 )
 from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import (
+    AccumulatingProtocol,
+    MemoryReactor,
+    MemoryReactorClock,
+)
 from twisted.web.http_headers import Headers
 from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.http.site import SynapseRequest
+from synapse.logging.context import ContextResourceUsage
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import PostgresEngine, create_engine
@@ -88,18 +93,19 @@ class TimedOutException(Exception):
     """
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class FakeChannel:
     """
     A fake Twisted Web Channel (the part that interfaces with the
     wire).
     """
 
-    site = attr.ib(type=Union[Site, "FakeSite"])
-    _reactor = attr.ib()
-    result = attr.ib(type=dict, default=attr.Factory(dict))
-    _ip = attr.ib(type=str, default="127.0.0.1")
+    site: Union[Site, "FakeSite"]
+    _reactor: MemoryReactor
+    result: dict = attr.Factory(dict)
+    _ip: str = "127.0.0.1"
     _producer: Optional[Union[IPullProducer, IPushProducer]] = None
+    resource_usage: Optional[ContextResourceUsage] = None
 
     @property
     def json_body(self):
@@ -168,6 +174,8 @@ class FakeChannel:
 
     def requestDone(self, _self):
         self.result["done"] = True
+        if isinstance(_self, SynapseRequest):
+            self.resource_usage = _self.logcontext.get_resource_usage()
 
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 272cd35402..72bf5b3d31 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
             expected_ignorer_user_ids,
         )
 
+    def assert_ignored(
+        self, ignorer_user_id: str, expected_ignored_user_ids: Set[str]
+    ) -> None:
+        self.assertEqual(
+            self.get_success(self.store.ignored_users(ignorer_user_id)),
+            expected_ignored_user_ids,
+        )
+
     def test_ignoring_users(self):
         """Basic adding/removing of users from the ignore list."""
         self._update_ignore_list("@other:test", "@another:remote")
+        self.assert_ignored(self.user, {"@other:test", "@another:remote"})
 
         # Check a user which no one ignores.
         self.assert_ignorers("@user:test", set())
@@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
 
         # Add one user, remove one user, and leave one user.
         self._update_ignore_list("@foo:test", "@another:remote")
+        self.assert_ignored(self.user, {"@foo:test", "@another:remote"})
 
         # Check the removed user.
         self.assert_ignorers("@other:test", set())
@@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         """Ensure that caching works properly between different users."""
         # The first user ignores a user.
         self._update_ignore_list("@other:test")
+        self.assert_ignored(self.user, {"@other:test"})
         self.assert_ignorers("@other:test", {self.user})
 
         # The second user ignores them.
         self._update_ignore_list("@other:test", ignorer_user_id="@second:test")
+        self.assert_ignored("@second:test", {"@other:test"})
         self.assert_ignorers("@other:test", {self.user, "@second:test"})
 
         # The first user un-ignores them.
         self._update_ignore_list()
+        self.assert_ignored(self.user, set())
         self.assert_ignorers("@other:test", {"@second:test"})
 
     def test_invalid_data(self):
         """Invalid data ends up clearing out the ignored users list."""
         # Add some data and ensure it is there.
         self._update_ignore_list("@other:test")
+        self.assert_ignored(self.user, {"@other:test"})
         self.assert_ignorers("@other:test", {self.user})
 
         # No ignored_users key.
@@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         )
 
         # No one ignores the user now.
+        self.assert_ignored(self.user, set())
         self.assert_ignorers("@other:test", set())
 
         # Add some data and ensure it is there.
         self._update_ignore_list("@other:test")
+        self.assert_ignored(self.user, {"@other:test"})
         self.assert_ignorers("@other:test", {self.user})
 
         # Invalid data.
@@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         )
 
         # No one ignores the user now.
+        self.assert_ignored(self.user, set())
         self.assert_ignorers("@other:test", set())
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 39dcc094bd..fd619b64d4 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -14,16 +14,23 @@
 
 from unittest.mock import Mock
 
+import yaml
+
 from twisted.internet.defer import Deferred, ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.server import HomeServer
 from synapse.storage.background_updates import BackgroundUpdater
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable, simple_async_mock
+from tests.unittest import override_config
 
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
         # the base test class should have run the real bg updates for us
         self.assertTrue(
@@ -34,50 +41,50 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         self.updates.register_background_update_handler(
             "test_update", self.update_handler
         )
+        self.store = self.hs.get_datastores().main
+
+    async def update(self, progress: JsonDict, count: int) -> int:
+        duration_ms = 10
+        await self.clock.sleep((count * duration_ms) / 1000)
+        progress = {"my_key": progress["my_key"] + 1}
+        await self.store.db_pool.runInteraction(
+            "update_progress",
+            self.updates._background_update_progress_txn,
+            "test_update",
+            progress,
+        )
+        return count
 
-    def test_do_background_update(self):
+    def test_do_background_update(self) -> None:
         # the time we claim it takes to update one item when running the update
         duration_ms = 10
 
         # the target runtime for each bg update
         target_background_update_duration_ms = 100
 
-        store = self.hs.get_datastores().main
         self.get_success(
-            store.db_pool.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
             )
         )
 
-        # first step: make a bit of progress
-        async def update(progress, count):
-            await self.clock.sleep((count * duration_ms) / 1000)
-            progress = {"my_key": progress["my_key"] + 1}
-            await store.db_pool.runInteraction(
-                "update_progress",
-                self.updates._background_update_progress_txn,
-                "test_update",
-                progress,
-            )
-            return count
-
-        self.update_handler.side_effect = update
+        self.update_handler.side_effect = self.update
         self.update_handler.reset_mock()
         res = self.get_success(
             self.updates.do_next_background_update(False),
-            by=0.01,
+            by=0.02,
         )
         self.assertFalse(res)
 
         # on the first call, we should get run with the default background update size
         self.update_handler.assert_called_once_with(
-            {"my_key": 1}, self.updates.MINIMUM_BACKGROUND_BATCH_SIZE
+            {"my_key": 1}, self.updates.default_background_batch_size
         )
 
         # second step: complete the update
         # we should now get run with a much bigger number of items to update
-        async def update(progress, count):
+        async def update(progress: JsonDict, count: int) -> int:
             self.assertEqual(progress, {"my_key": 2})
             self.assertAlmostEqual(
                 count,
@@ -99,16 +106,234 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         self.assertTrue(result)
         self.assertFalse(self.update_handler.called)
 
+    @override_config(
+        yaml.safe_load(
+            """
+            background_updates:
+                default_batch_size: 20
+            """
+        )
+    )
+    def test_background_update_default_batch_set_by_config(self) -> None:
+        """
+        Test that the background update is run with the default_batch_size set by the config
+        """
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+            )
+        )
+
+        self.update_handler.side_effect = self.update
+        self.update_handler.reset_mock()
+        res = self.get_success(
+            self.updates.do_next_background_update(False),
+            by=0.01,
+        )
+        self.assertFalse(res)
+
+        # on the first call, we should get run with the default background update size specified in the config
+        self.update_handler.assert_called_once_with({"my_key": 1}, 20)
+
+    def test_background_update_default_sleep_behavior(self) -> None:
+        """
+        Test default background update behavior, which is to sleep
+        """
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+            )
+        )
+
+        self.update_handler.side_effect = self.update
+        self.update_handler.reset_mock()
+        self.updates.start_doing_background_updates()
+
+        # 2: advance the reactor less than the default sleep duration (1000ms)
+        self.reactor.pump([0.5])
+        # check that an update has not been run
+        self.update_handler.assert_not_called()
+
+        # advance reactor past default sleep duration
+        self.reactor.pump([1])
+        # check that update has been run
+        self.update_handler.assert_called()
+
+    @override_config(
+        yaml.safe_load(
+            """
+            background_updates:
+                sleep_duration_ms: 500
+            """
+        )
+    )
+    def test_background_update_sleep_set_in_config(self) -> None:
+        """
+        Test that changing the sleep time in the config changes how long it sleeps
+        """
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+            )
+        )
+
+        self.update_handler.side_effect = self.update
+        self.update_handler.reset_mock()
+        self.updates.start_doing_background_updates()
+
+        # 2: advance the reactor less than the configured sleep duration (500ms)
+        self.reactor.pump([0.45])
+        # check that an update has not been run
+        self.update_handler.assert_not_called()
+
+        # advance reactor past config sleep duration but less than default duration
+        self.reactor.pump([0.75])
+        # check that update has been run
+        self.update_handler.assert_called()
+
+    @override_config(
+        yaml.safe_load(
+            """
+            background_updates:
+                sleep_enabled: false
+            """
+        )
+    )
+    def test_disabling_background_update_sleep(self) -> None:
+        """
+        Test that disabling sleep in the config results in bg update not sleeping
+        """
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+            )
+        )
+
+        self.update_handler.side_effect = self.update
+        self.update_handler.reset_mock()
+        self.updates.start_doing_background_updates()
+
+        # 2: advance the reactor very little
+        self.reactor.pump([0.025])
+        # check that an update has run
+        self.update_handler.assert_called()
+
+    @override_config(
+        yaml.safe_load(
+            """
+            background_updates:
+                background_update_duration_ms: 500
+            """
+        )
+    )
+    def test_background_update_duration_set_in_config(self) -> None:
+        """
+        Test that the desired duration set in the config is used in determining batch size
+        """
+        # Duration of one background update item
+        duration_ms = 10
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+            )
+        )
+
+        self.update_handler.side_effect = self.update
+        self.update_handler.reset_mock()
+        res = self.get_success(
+            self.updates.do_next_background_update(False),
+            by=0.02,
+        )
+        self.assertFalse(res)
+
+        # the first update was run with the default batch size, this should be run with 500ms as the
+        # desired duration
+        async def update(progress: JsonDict, count: int) -> int:
+            self.assertEqual(progress, {"my_key": 2})
+            self.assertAlmostEqual(
+                count,
+                500 / duration_ms,
+                places=0,
+            )
+            await self.updates._end_background_update("test_update")
+            return count
+
+        self.update_handler.side_effect = update
+        self.get_success(self.updates.do_next_background_update(False))
+
+    @override_config(
+        yaml.safe_load(
+            """
+            background_updates:
+                min_batch_size: 5
+            """
+        )
+    )
+    def test_background_update_min_batch_set_in_config(self) -> None:
+        """
+        Test that the minimum batch size set in the config is used
+        """
+        # a very long-running individual update
+        duration_ms = 50
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+            )
+        )
+
+        # Run the update with the long-running update item
+        async def update_long(progress: JsonDict, count: int) -> int:
+            await self.clock.sleep((count * duration_ms) / 1000)
+            progress = {"my_key": progress["my_key"] + 1}
+            await self.store.db_pool.runInteraction(
+                "update_progress",
+                self.updates._background_update_progress_txn,
+                "test_update",
+                progress,
+            )
+            return count
+
+        self.update_handler.side_effect = update_long
+        self.update_handler.reset_mock()
+        res = self.get_success(
+            self.updates.do_next_background_update(False),
+            by=1,
+        )
+        self.assertFalse(res)
+
+        # the first update was run with the default batch size, this should be run with minimum batch size
+        # as the first items took a very long time
+        async def update_short(progress: JsonDict, count: int) -> int:
+            self.assertEqual(progress, {"my_key": 2})
+            self.assertEqual(count, 5)
+            await self.updates._end_background_update("test_update")
+            return count
+
+        self.update_handler.side_effect = update_short
+        self.get_success(self.updates.do_next_background_update(False))
+
 
 class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
         # the base test class should have run the real bg updates for us
         self.assertTrue(
             self.get_success(self.updates.has_completed_background_updates())
         )
 
-        self.update_deferred = Deferred()
+        self.update_deferred: Deferred[int] = Deferred()
         self.update_handler = Mock(return_value=self.update_deferred)
         self.updates.register_background_update_handler(
             "test_update", self.update_handler
@@ -137,7 +362,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
             ),
         )
 
-    def test_controller(self):
+    def test_controller(self) -> None:
         store = self.hs.get_datastores().main
         self.get_success(
             store.db_pool.simple_insert(
@@ -147,7 +372,7 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         )
 
         # Set the return value for the context manager.
-        enter_defer = Deferred()
+        enter_defer: Deferred[int] = Deferred()
         self._update_ctx_manager.__aenter__ = Mock(return_value=enter_defer)
 
         # Start the background update.
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 6fbac0ab14..a40fc20ef9 100644
--- a/tests/storage/test_database.py
+++ b/tests/storage/test_database.py
@@ -12,25 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.database import make_tuple_comparison_clause
-from synapse.storage.engines import BaseDatabaseEngine
+from typing import Callable, Tuple
+from unittest.mock import Mock, call
 
-from tests import unittest
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.server import HomeServer
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.util import Clock
 
-def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
-    # returns a DatabaseEngine, circumventing the abc mechanism
-    # any kwargs are set as attributes on the class before instantiating it
-    t = type(
-        "TestBaseDatabaseEngine",
-        (BaseDatabaseEngine,),
-        dict(BaseDatabaseEngine.__dict__),
-    )
-    # defeat the abc mechanism
-    t.__abstractmethods__ = set()
-    for k, v in kwargs.items():
-        setattr(t, k, v)
-    return t(None, None)
+from tests import unittest
 
 
 class TupleComparisonClauseTestCase(unittest.TestCase):
@@ -38,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase):
         clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
         self.assertEqual(clause, "(a,b) > (?,?)")
         self.assertEqual(args, [1, 2])
+
+
+class CallbacksTestCase(unittest.HomeserverTestCase):
+    """Tests for transaction callbacks."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.db_pool: DatabasePool = self.store.db_pool
+
+    def _run_interaction(
+        self, func: Callable[[LoggingTransaction], object]
+    ) -> Tuple[Mock, Mock]:
+        """Run the given function in a database transaction, with callbacks registered.
+
+        Args:
+            func: The function to be run in a transaction. The transaction will be
+                retried if `func` raises an `OperationalError`.
+
+        Returns:
+            Two mocks, which were registered as an `after_callback` and an
+            `exception_callback` respectively, on every transaction attempt.
+        """
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            func(txn)
+
+        try:
+            self.get_success_or_raise(
+                self.db_pool.runInteraction("test_transaction", _test_txn)
+            )
+        except Exception:
+            pass
+
+        return after_callback, exception_callback
+
+    def test_after_callback(self) -> None:
+        """Test that the after callback is called when a transaction succeeds."""
+        after_callback, exception_callback = self._run_interaction(lambda txn: None)
+
+        after_callback.assert_called_once_with(123, 456, extra=789)
+        exception_callback.assert_not_called()
+
+    def test_exception_callback(self) -> None:
+        """Test that the exception callback is called when a transaction fails."""
+        _test_txn = Mock(side_effect=ZeroDivisionError)
+        after_callback, exception_callback = self._run_interaction(_test_txn)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_called_once_with(987, 654, extra=321)
+
+    def test_failed_retry(self) -> None:
+        """Test that the exception callback is called for every failed attempt."""
+        # Always raise an `OperationalError`.
+        _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
+        after_callback, exception_callback = self._run_interaction(_test_txn)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_has_calls(
+            [
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+            ]
+        )
+        self.assertEqual(exception_callback.call_count, 6)  # no additional calls
+
+    def test_successful_retry(self) -> None:
+        """Test callbacks for a failed transaction followed by a successful attempt."""
+        # Raise an `OperationalError` on the first attempt only.
+        _test_txn = Mock(
+            side_effect=[self.db_pool.engine.module.OperationalError, None]
+        )
+        after_callback, exception_callback = self._run_interaction(_test_txn)
+
+        # Calling both `after_callback`s when the first attempt failed is rather
+        # surprising (#12184). Let's document the behaviour in a test.
+        after_callback.assert_has_calls(
+            [
+                call(123, 456, extra=789),
+                call(123, 456, extra=789),
+            ]
+        )
+        self.assertEqual(after_callback.call_count, 2)  # no additional calls
+        exception_callback.assert_not_called()
+
+
+class CancellationTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.db_pool: DatabasePool = self.store.db_pool
+
+    def test_after_callback(self) -> None:
+        """Test that the after callback is called when a transaction succeeds."""
+        d: "Deferred[None]"
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            d.cancel()
+
+        d = defer.ensureDeferred(
+            self.db_pool.runInteraction("test_transaction", _test_txn)
+        )
+        self.get_failure(d, CancelledError)
+
+        after_callback.assert_called_once_with(123, 456, extra=789)
+        exception_callback.assert_not_called()
+
+    def test_exception_callback(self) -> None:
+        """Test that the exception callback is called when a transaction fails."""
+        d: "Deferred[None]"
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            d.cancel()
+            # Simulate a retryable failure on every attempt.
+            raise self.db_pool.engine.module.OperationalError()
+
+        d = defer.ensureDeferred(
+            self.db_pool.runInteraction("test_transaction", _test_txn)
+        )
+        self.get_failure(d, CancelledError)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_has_calls(
+            [
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+                call(987, 654, extra=321),
+            ]
+        )
+        self.assertEqual(exception_callback.call_count, 6)  # no additional calls
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 6ac4b93f98..395396340b 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -13,9 +13,13 @@
 # limitations under the License.
 from typing import List, Optional
 
-from synapse.storage.database import DatabasePool
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.server import HomeServer
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import IncorrectDatabaseSetup
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
-    def _setup_db(self, txn):
+    def _setup_db(self, txn: LoggingTransaction) -> None:
         txn.execute("CREATE SEQUENCE foobar_seq")
         txn.execute(
             """
@@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
 
-    def _insert_rows(self, instance_name: str, number: int):
+    def _insert_rows(self, instance_name: str, number: int) -> None:
         """Insert N rows as the given instance, inserting with stream IDs pulled
         from the postgres sequence.
         """
 
-        def _insert(txn):
+        def _insert(txn: LoggingTransaction) -> None:
             for _ in range(number):
                 txn.execute(
                     "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
@@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
 
-    def _insert_row_with_id(self, instance_name: str, stream_id: int):
+    def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None:
         """Insert one row as the given instance with given stream_id, updating
         the postgres sequence position to match.
         """
 
-        def _insert(txn):
+        def _insert(txn: LoggingTransaction) -> None:
             txn.execute(
                 "INSERT INTO foobar VALUES (?, ?)",
                 (
@@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
 
-    def test_empty(self):
+    def test_empty(self) -> None:
         """Test an ID generator against an empty database gives sensible
         current positions.
         """
@@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # The table is empty so we expect an empty map for positions
         self.assertEqual(id_gen.get_positions(), {})
 
-    def test_single_instance(self):
+    def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
         correctly.
         """
@@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 8)
 
@@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 8})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
-    def test_out_of_order_finish(self):
+    def test_out_of_order_finish(self) -> None:
         """Test that IDs persisted out of order are correctly handled"""
 
         # Prefill table with 7 rows written by 'master'
@@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 11})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
 
-    def test_multi_instance(self):
+    def test_multi_instance(self) -> None:
         """Test that reads and writes from multiple processes are handled
         correctly.
         """
@@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with first_id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 8)
 
@@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # ... but calling `get_next` on the second instance should give a unique
         # stream ID
 
-        async def _get_next_async():
+        async def _get_next_async2() -> None:
             async with second_id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 9)
 
@@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                     second_id_gen.get_positions(), {"first": 3, "second": 7}
                 )
 
-        self.get_success(_get_next_async())
+        self.get_success(_get_next_async2())
 
         self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
 
@@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen.advance("first", 8)
         self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
 
-    def test_get_next_txn(self):
+    def test_get_next_txn(self) -> None:
         """Test that the `get_next_txn` function works correctly."""
 
         # Prefill table with 7 rows written by 'master'
@@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
-        def _get_next_txn(txn):
+        def _get_next_txn(txn: LoggingTransaction) -> None:
             stream_id = id_gen.get_next_txn(txn)
             self.assertEqual(stream_id, 8)
 
@@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 8})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
-    def test_get_persisted_upto_position(self):
+    def test_get_persisted_upto_position(self) -> None:
         """Test that `get_persisted_upto_position` correctly tracks updates to
         positions.
         """
@@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen.advance("second", 15)
         self.assertEqual(id_gen.get_persisted_upto_position(), 11)
 
-    def test_get_persisted_upto_position_get_next(self):
+    def test_get_persisted_upto_position_get_next(self) -> None:
         """Test that `get_persisted_upto_position` correctly tracks updates to
         positions when `get_next` is called.
         """
@@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 6)
                 self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # `persisted_upto_position` in this case, then it will be correct in the
         # other cases that are tested above (since they'll hit the same code).
 
-    def test_restart_during_out_of_order_persistence(self):
+    def test_restart_during_out_of_order_persistence(self) -> None:
         """Test that restarting a process while another process is writing out
         of order updates are handled correctly.
         """
@@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen_worker.advance("master", 9)
         self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
 
-    def test_writer_config_change(self):
+    def test_writer_config_change(self) -> None:
         """Test that changing the writer config correctly works."""
 
         self._insert_row_with_id("first", 3)
@@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         # Check that we get a sane next stream ID with this new config.
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen_3.get_next() as stream_id:
                 self.assertEqual(stream_id, 6)
 
@@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
         self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
 
-    def test_sequence_consistency(self):
+    def test_sequence_consistency(self) -> None:
         """Test that we error out if the table and sequence diverges."""
 
         # Prefill with some rows
@@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
-    def _setup_db(self, txn):
+    def _setup_db(self, txn: LoggingTransaction) -> None:
         txn.execute("CREATE SEQUENCE foobar_seq")
         txn.execute(
             """
@@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         return self.get_success(self.db_pool.runWithConnection(_create))
 
-    def _insert_row(self, instance_name: str, stream_id: int):
+    def _insert_row(self, instance_name: str, stream_id: int) -> None:
         """Insert one row as the given instance with given stream_id."""
 
-        def _insert(txn):
+        def _insert(txn: LoggingTransaction) -> None:
             txn.execute(
                 "INSERT INTO foobar VALUES (?, ?)",
                 (
@@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
 
-    def test_single_instance(self):
+    def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
         correctly.
         """
         id_gen = self._create_id_generator()
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen.get_next() as stream_id:
                 self._insert_row("master", stream_id)
 
@@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
         self.assertEqual(id_gen.get_persisted_upto_position(), -1)
 
-        async def _get_next_async2():
+        async def _get_next_async2() -> None:
             async with id_gen.get_next_mult(3) as stream_ids:
                 for stream_id in stream_ids:
                     self._insert_row("master", stream_id)
@@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
         self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
 
-    def test_multiple_instance(self):
+    def test_multiple_instance(self) -> None:
         """Tests that having multiple instances that get advanced over
         federation works corretly.
         """
         id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
         id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
 
-        async def _get_next_async():
+        async def _get_next_async() -> None:
             async with id_gen_1.get_next() as stream_id:
                 self._insert_row("first", stream_id)
                 id_gen_2.advance("first", stream_id)
@@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
 
-        async def _get_next_async2():
+        async def _get_next_async2() -> None:
             async with id_gen_2.get_next() as stream_id:
                 self._insert_row("second", stream_id)
                 id_gen_1.advance("second", stream_id)
@@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
-    def _setup_db(self, txn):
+    def _setup_db(self, txn: LoggingTransaction) -> None:
         txn.execute("CREATE SEQUENCE foobar_seq")
         txn.execute(
             """
@@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         from the postgres sequence.
         """
 
-        def _insert(txn):
+        def _insert(txn: LoggingTransaction) -> None:
             for _ in range(number):
                 txn.execute(
                     "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
@@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
 
-    def test_load_existing_stream(self):
+    def test_load_existing_stream(self) -> None:
         """Test creating ID gens with multiple tables that have rows from after
         the position in `stream_positions` table.
         """
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 6a1cf33054..eaa0d7d749 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -129,21 +129,19 @@ class PaginationTestCase(HomeserverTestCase):
 
     def test_filter_relation_senders(self):
         # Messages which second user reacted to.
-        filter = {"io.element.relation_senders": [self.second_user_id]}
+        filter = {"related_by_senders": [self.second_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0].event_id, self.event_id_1)
 
         # Messages which third user reacted to.
-        filter = {"io.element.relation_senders": [self.third_user_id]}
+        filter = {"related_by_senders": [self.third_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0].event_id, self.event_id_2)
 
         # Messages which either user reacted to.
-        filter = {
-            "io.element.relation_senders": [self.second_user_id, self.third_user_id]
-        }
+        filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 2, chunk)
         self.assertCountEqual(
@@ -152,20 +150,20 @@ class PaginationTestCase(HomeserverTestCase):
 
     def test_filter_relation_type(self):
         # Messages which have annotations.
-        filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
+        filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0].event_id, self.event_id_1)
 
         # Messages which have references.
-        filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
+        filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0].event_id, self.event_id_2)
 
         # Messages which have either annotations or references.
         filter = {
-            "io.element.relation_types": [
+            "related_by_rel_types": [
                 RelationTypes.ANNOTATION,
                 RelationTypes.REFERENCE,
             ]
@@ -179,8 +177,8 @@ class PaginationTestCase(HomeserverTestCase):
     def test_filter_relation_senders_and_type(self):
         # Messages which second user reacted to.
         filter = {
-            "io.element.relation_senders": [self.second_user_id],
-            "io.element.relation_types": [RelationTypes.ANNOTATION],
+            "related_by_senders": [self.second_user_id],
+            "related_by_rel_types": [RelationTypes.ANNOTATION],
         }
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
@@ -201,7 +199,7 @@ class PaginationTestCase(HomeserverTestCase):
             tok=self.second_tok,
         )
 
-        filter = {"io.element.relation_senders": [self.second_user_id]}
+        filter = {"related_by_senders": [self.second_user_id]}
         chunk = self._filter_messages(filter)
         self.assertEqual(len(chunk), 1, chunk)
         self.assertEqual(chunk[0].event_id, self.event_id_1)
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
new file mode 100644
index 0000000000..ba53c22818
--- /dev/null
+++ b/tests/storage/test_unsafe_locale.py
@@ -0,0 +1,46 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import MagicMock, patch
+
+from synapse.storage.database import make_conn
+from synapse.storage.engines._base import IncorrectDatabaseSetup
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class UnsafeLocaleTest(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    @patch("synapse.storage.engines.postgres.PostgresEngine.get_db_locale")
+    def test_unsafe_locale(self, mock_db_locale: MagicMock) -> None:
+        mock_db_locale.return_value = ("B", "B")
+        database = self.hs.get_datastores().databases[0]
+
+        db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
+        with self.assertRaises(IncorrectDatabaseSetup):
+            database.engine.check_database(db_conn)
+        with self.assertRaises(IncorrectDatabaseSetup):
+            database.engine.check_new_database(db_conn)
+        db_conn.close()
+
+    def test_safe_locale(self) -> None:
+        database = self.hs.get_datastores().databases[0]
+
+        db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
+        with db_conn.cursor() as txn:
+            res = database.engine.get_db_locale(txn)
+        self.assertEqual(res, ("C", "C"))
+        db_conn.close()
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 219b5660b1..532e3fe9cd 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 import logging
 from typing import Optional
+from unittest.mock import patch
 
 from synapse.api.room_versions import RoomVersions
-from synapse.events import EventBase
-from synapse.types import JsonDict
-from synapse.visibility import filter_events_for_server
+from synapse.events import EventBase, make_event_from_dict
+from synapse.types import JsonDict, create_requester
+from synapse.visibility import filter_events_for_client, filter_events_for_server
 
 from tests import unittest
 from tests.utils import create_room
@@ -185,3 +186,72 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         self.get_success(self.storage.persistence.persist_event(event, context))
         return event
+
+
+class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
+    def test_out_of_band_invite_rejection(self):
+        # this is where we have received an invite event over federation, and then
+        # rejected it.
+        invite_pdu = {
+            "room_id": "!room:id",
+            "depth": 1,
+            "auth_events": [],
+            "prev_events": [],
+            "origin_server_ts": 1,
+            "sender": "@someone:" + self.OTHER_SERVER_NAME,
+            "type": "m.room.member",
+            "state_key": "@user:test",
+            "content": {"membership": "invite"},
+        }
+        self.add_hashes_and_signatures(invite_pdu)
+        invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id
+
+        self.get_success(
+            self.hs.get_federation_server().on_invite_request(
+                self.OTHER_SERVER_NAME,
+                invite_pdu,
+                "9",
+            )
+        )
+
+        # stub out do_remotely_reject_invite so that we fall back to a locally-
+        # generated rejection
+        with patch.object(
+            self.hs.get_federation_handler(),
+            "do_remotely_reject_invite",
+            side_effect=Exception(),
+        ):
+            reject_event_id, _ = self.get_success(
+                self.hs.get_room_member_handler().remote_reject_invite(
+                    invite_event_id,
+                    txn_id=None,
+                    requester=create_requester("@user:test"),
+                    content={},
+                )
+            )
+
+        invite_event, reject_event = self.get_success(
+            self.hs.get_datastores().main.get_events_as_list(
+                [invite_event_id, reject_event_id]
+            )
+        )
+
+        # the invited user should be able to see both the invite and the rejection
+        self.assertEqual(
+            self.get_success(
+                filter_events_for_client(
+                    self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+                )
+            ),
+            [invite_event, reject_event],
+        )
+
+        # other users should see neither
+        self.assertEqual(
+            self.get_success(
+                filter_events_for_client(
+                    self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+                )
+            ),
+            [],
+        )
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 19741ffcda..48e616ac74 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -17,7 +17,7 @@ from typing import Set
 from unittest import mock
 
 from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
 
 from synapse.api.errors import SynapseError
 from synapse.logging.context import (
@@ -28,7 +28,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, lru_cache
+from synapse.util.caches.descriptors import cached, cachedList, lru_cache
 
 from tests import unittest
 from tests.test_utils import get_awaitable_result
@@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase):
         self.assertEqual(r, "chips")
         obj.mock.assert_not_called()
 
+    @defer.inlineCallbacks
+    def test_cache_uncached_args(self):
+        """
+        Only the arguments not named in uncached_args should matter to the cache
+
+        Note that this is identical to test_cache_num_args, but provides the
+        arguments differently.
+        """
+
+        class Cls:
+            # Note that it is important that this is not the last argument to
+            # test behaviour of skipping arguments properly.
+            @descriptors.cached(uncached_args=("arg2",))
+            def fn(self, arg1, arg2, arg3):
+                return self.mock(arg1, arg2, arg3)
+
+            def __init__(self):
+                self.mock = mock.Mock()
+
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = yield obj.fn(1, 2, 3)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, 2, 3)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = yield obj.fn(2, 3, 4)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(2, 3, 4)
+        obj.mock.reset_mock()
+
+        # the two values should now be cached; we should be able to vary
+        # the second argument and still get the cached result.
+        r = yield obj.fn(1, 4, 3)
+        self.assertEqual(r, "fish")
+        r = yield obj.fn(2, 5, 4)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_not_called()
+
+    @defer.inlineCallbacks
+    def test_cache_kwargs(self):
+        """Test that keyword arguments are treated properly"""
+
+        class Cls:
+            def __init__(self):
+                self.mock = mock.Mock()
+
+            @descriptors.cached()
+            def fn(self, arg1, kwarg1=2):
+                return self.mock(arg1, kwarg1=kwarg1)
+
+        obj = Cls()
+        obj.mock.return_value = "fish"
+        r = yield obj.fn(1, kwarg1=2)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_called_once_with(1, kwarg1=2)
+        obj.mock.reset_mock()
+
+        # a call with different params should call the mock again
+        obj.mock.return_value = "chips"
+        r = yield obj.fn(1, kwarg1=3)
+        self.assertEqual(r, "chips")
+        obj.mock.assert_called_once_with(1, kwarg1=3)
+        obj.mock.reset_mock()
+
+        # the values should now be cached.
+        r = yield obj.fn(1, kwarg1=2)
+        self.assertEqual(r, "fish")
+        # We should be able to not provide kwarg1 and get the cached value back.
+        r = yield obj.fn(1)
+        self.assertEqual(r, "fish")
+        # Keyword arguments can be in any order.
+        r = yield obj.fn(kwarg1=2, arg1=1)
+        self.assertEqual(r, "fish")
+        obj.mock.assert_not_called()
+
     def test_cache_with_sync_exception(self):
         """If the wrapped function throws synchronously, things should continue to work"""
 
@@ -415,6 +493,74 @@ class DescriptorTestCase(unittest.TestCase):
         obj.invalidate()
         top_invalidate.assert_called_once()
 
+    def test_cancel(self):
+        """Test that cancelling a lookup does not cancel other lookups"""
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            @cached()
+            async def fn(self, arg1):
+                await complete_lookup
+                return str(arg1)
+
+        obj = Cls()
+
+        d1 = obj.fn(123)
+        d2 = obj.fn(123)
+        self.assertFalse(d1.called)
+        self.assertFalse(d2.called)
+
+        # Cancel `d1`, which is the lookup that caused `fn` to run.
+        d1.cancel()
+
+        # `d2` should complete normally.
+        complete_lookup.callback(None)
+        self.failureResultOf(d1, CancelledError)
+        self.assertEqual(d2.result, "123")
+
+    def test_cancel_logcontexts(self):
+        """Test that cancellation does not break logcontexts.
+
+        * The `CancelledError` must be raised with the correct logcontext.
+        * The inner lookup must not resume with a finished logcontext.
+        * The inner lookup must not restore a finished logcontext when done.
+        """
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            inner_context_was_finished = False
+
+            @cached()
+            async def fn(self, arg1):
+                await make_deferred_yieldable(complete_lookup)
+                self.inner_context_was_finished = current_context().finished
+                return str(arg1)
+
+        obj = Cls()
+
+        async def do_lookup():
+            with LoggingContext("c1") as c1:
+                try:
+                    await obj.fn(123)
+                    self.fail("No CancelledError thrown")
+                except CancelledError:
+                    self.assertEqual(
+                        current_context(),
+                        c1,
+                        "CancelledError was not raised with the correct logcontext",
+                    )
+                    # suppress the error and succeed
+
+        d = defer.ensureDeferred(do_lookup())
+        d.cancel()
+
+        complete_lookup.callback(None)
+        self.successResultOf(d)
+        self.assertFalse(
+            obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+        )
+        self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
 
 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     """More tests for @cached
@@ -656,7 +802,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1")
+            @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             async def list_fn(self, args1, arg2):
                 assert current_context().name == "c1"
                 # we want this to behave like an asynchronous function
@@ -715,7 +861,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1):
                 pass
 
-            @descriptors.cachedList("fn", "args1")
+            @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             def list_fn(self, args1) -> "Deferred[dict]":
                 return self.mock(args1)
 
@@ -758,7 +904,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1")
+            @descriptors.cachedList(cached_method_name="fn", list_name="args1")
             async def list_fn(self, args1, arg2):
                 # we want this to behave like an asynchronous function
                 await run_on_reactor()
@@ -787,3 +933,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         obj.fn.invalidate((10, 2))
         invalidate0.assert_called_once()
         invalidate1.assert_called_once()
+
+    def test_cancel(self):
+        """Test that cancelling a lookup does not cancel other lookups"""
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            @cached()
+            def fn(self, arg1):
+                pass
+
+            @cachedList(cached_method_name="fn", list_name="args")
+            async def list_fn(self, args):
+                await complete_lookup
+                return {arg: str(arg) for arg in args}
+
+        obj = Cls()
+
+        d1 = obj.list_fn([123, 456])
+        d2 = obj.list_fn([123, 456, 789])
+        self.assertFalse(d1.called)
+        self.assertFalse(d2.called)
+
+        d1.cancel()
+
+        # `d2` should complete normally.
+        complete_lookup.callback(None)
+        self.failureResultOf(d1, CancelledError)
+        self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
+
+    def test_cancel_logcontexts(self):
+        """Test that cancellation does not break logcontexts.
+
+        * The `CancelledError` must be raised with the correct logcontext.
+        * The inner lookup must not resume with a finished logcontext.
+        * The inner lookup must not restore a finished logcontext when done.
+        """
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            inner_context_was_finished = False
+
+            @cached()
+            def fn(self, arg1):
+                pass
+
+            @cachedList(cached_method_name="fn", list_name="args")
+            async def list_fn(self, args):
+                await make_deferred_yieldable(complete_lookup)
+                self.inner_context_was_finished = current_context().finished
+                return {arg: str(arg) for arg in args}
+
+        obj = Cls()
+
+        async def do_lookup():
+            with LoggingContext("c1") as c1:
+                try:
+                    await obj.list_fn([123])
+                    self.fail("No CancelledError thrown")
+                except CancelledError:
+                    self.assertEqual(
+                        current_context(),
+                        c1,
+                        "CancelledError was not raised with the correct logcontext",
+                    )
+                    # suppress the error and succeed
+
+        d = defer.ensureDeferred(do_lookup())
+        d.cancel()
+
+        complete_lookup.callback(None)
+        self.successResultOf(d)
+        self.assertFalse(
+            obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+        )
+        self.assertEqual(current_context(), SENTINEL_CONTEXT)
diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py
index 362014f4cb..e5bc416de1 100644
--- a/tests/util/test_async_helpers.py
+++ b/tests/util/test_async_helpers.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 import traceback
 
+from parameterized import parameterized_class
+
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
 from twisted.internet.task import Clock
@@ -23,10 +25,12 @@ from synapse.logging.context import (
     LoggingContext,
     PreserveLoggingContext,
     current_context,
+    make_deferred_yieldable,
 )
 from synapse.util.async_helpers import (
     ObservableDeferred,
     concurrently_execute,
+    delay_cancellation,
     stop_cancellation,
     timeout_deferred,
 )
@@ -100,6 +104,34 @@ class ObservableDeferredTest(TestCase):
         self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
         self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
 
+    def test_cancellation(self):
+        """Test that cancelling an observer does not affect other observers."""
+        origin_d: "Deferred[int]" = Deferred()
+        observable = ObservableDeferred(origin_d, consumeErrors=True)
+
+        observer1 = observable.observe()
+        observer2 = observable.observe()
+        observer3 = observable.observe()
+
+        self.assertFalse(observer1.called)
+        self.assertFalse(observer2.called)
+        self.assertFalse(observer3.called)
+
+        # cancel the second observer
+        observer2.cancel()
+        self.assertFalse(observer1.called)
+        self.failureResultOf(observer2, CancelledError)
+        self.assertFalse(observer3.called)
+
+        # other observers resolve as normal
+        origin_d.callback(123)
+        self.assertEqual(observer1.result, 123, "observer 1 callback result")
+        self.assertEqual(observer3.result, 123, "observer 3 callback result")
+
+        # additional observers resolve as normal
+        observer4 = observable.observe()
+        self.assertEqual(observer4.result, 123, "observer 4 callback result")
+
 
 class TimeoutDeferredTest(TestCase):
     def setUp(self):
@@ -285,13 +317,27 @@ class ConcurrentlyExecuteTest(TestCase):
         self.successResultOf(d2)
 
 
-class StopCancellationTests(TestCase):
-    """Tests for the `stop_cancellation` function."""
+@parameterized_class(
+    ("wrapper",),
+    [("stop_cancellation",), ("delay_cancellation",)],
+)
+class CancellationWrapperTests(TestCase):
+    """Common tests for the `stop_cancellation` and `delay_cancellation` functions."""
+
+    wrapper: str
+
+    def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]":
+        if self.wrapper == "stop_cancellation":
+            return stop_cancellation(deferred)
+        elif self.wrapper == "delay_cancellation":
+            return delay_cancellation(deferred)
+        else:
+            raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
 
     def test_succeed(self):
         """Test that the new `Deferred` receives the result."""
         deferred: "Deferred[str]" = Deferred()
-        wrapper_deferred = stop_cancellation(deferred)
+        wrapper_deferred = self.wrap_deferred(deferred)
 
         # Success should propagate through.
         deferred.callback("success")
@@ -301,7 +347,7 @@ class StopCancellationTests(TestCase):
     def test_failure(self):
         """Test that the new `Deferred` receives the `Failure`."""
         deferred: "Deferred[str]" = Deferred()
-        wrapper_deferred = stop_cancellation(deferred)
+        wrapper_deferred = self.wrap_deferred(deferred)
 
         # Failure should propagate through.
         deferred.errback(ValueError("abc"))
@@ -309,6 +355,10 @@ class StopCancellationTests(TestCase):
         self.failureResultOf(wrapper_deferred, ValueError)
         self.assertIsNone(deferred.result, "`Failure` was not consumed")
 
+
+class StopCancellationTests(TestCase):
+    """Tests for the `stop_cancellation` function."""
+
     def test_cancellation(self):
         """Test that cancellation of the new `Deferred` leaves the original running."""
         deferred: "Deferred[str]" = Deferred()
@@ -319,11 +369,101 @@ class StopCancellationTests(TestCase):
         self.assertTrue(wrapper_deferred.called)
         self.failureResultOf(wrapper_deferred, CancelledError)
         self.assertFalse(
-            deferred.called, "Original `Deferred` was unexpectedly cancelled."
+            deferred.called, "Original `Deferred` was unexpectedly cancelled"
+        )
+
+        # Now make the original `Deferred` fail.
+        # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+        # in logs.
+        deferred.errback(ValueError("abc"))
+        self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+
+class DelayCancellationTests(TestCase):
+    """Tests for the `delay_cancellation` function."""
+
+    def test_cancellation(self):
+        """Test that cancellation of the new `Deferred` waits for the original."""
+        deferred: "Deferred[str]" = Deferred()
+        wrapper_deferred = delay_cancellation(deferred)
+
+        # Cancel the new `Deferred`.
+        wrapper_deferred.cancel()
+        self.assertNoResult(wrapper_deferred)
+        self.assertFalse(
+            deferred.called, "Original `Deferred` was unexpectedly cancelled"
+        )
+
+        # Now make the original `Deferred` fail.
+        # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
+        # in logs.
+        deferred.errback(ValueError("abc"))
+        self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+        # Now that the original `Deferred` has failed, we should get a `CancelledError`.
+        self.failureResultOf(wrapper_deferred, CancelledError)
+
+    def test_suppresses_second_cancellation(self):
+        """Test that a second cancellation is suppressed.
+
+        Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
+        """
+        deferred: "Deferred[str]" = Deferred()
+        wrapper_deferred = delay_cancellation(deferred)
+
+        # Cancel the new `Deferred`, twice.
+        wrapper_deferred.cancel()
+        wrapper_deferred.cancel()
+        self.assertNoResult(wrapper_deferred)
+        self.assertFalse(
+            deferred.called, "Original `Deferred` was unexpectedly cancelled"
         )
 
-        # Now make the inner `Deferred` fail.
+        # Now make the original `Deferred` fail.
         # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
         # in logs.
         deferred.errback(ValueError("abc"))
         self.assertIsNone(deferred.result, "`Failure` was not consumed")
+
+        # Now that the original `Deferred` has failed, we should get a `CancelledError`.
+        self.failureResultOf(wrapper_deferred, CancelledError)
+
+    def test_propagates_cancelled_error(self):
+        """Test that a `CancelledError` from the original `Deferred` gets propagated."""
+        deferred: "Deferred[str]" = Deferred()
+        wrapper_deferred = delay_cancellation(deferred)
+
+        # Fail the original `Deferred` with a `CancelledError`.
+        cancelled_error = CancelledError()
+        deferred.errback(cancelled_error)
+
+        # The new `Deferred` should fail with exactly the same `CancelledError`.
+        self.assertTrue(wrapper_deferred.called)
+        self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
+
+    def test_preserves_logcontext(self):
+        """Test that logging contexts are preserved."""
+        blocking_d: "Deferred[None]" = Deferred()
+
+        async def inner():
+            await make_deferred_yieldable(blocking_d)
+
+        async def outer():
+            with LoggingContext("c") as c:
+                try:
+                    await delay_cancellation(defer.ensureDeferred(inner()))
+                    self.fail("`CancelledError` was not raised")
+                except CancelledError:
+                    self.assertEqual(c, current_context())
+                    # Succeed with no error, unless the logging context is wrong.
+
+        # Run and block inside `inner()`.
+        d = defer.ensureDeferred(outer())
+        self.assertEqual(SENTINEL_CONTEXT, current_context())
+
+        d.cancel()
+
+        # Now unblock. `outer()` will consume the `CancelledError` and check the
+        # logging context.
+        blocking_d.callback(None)
+        self.successResultOf(d)
diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py
index 3c07252252..5d1aa025d1 100644
--- a/tests/util/test_check_dependencies.py
+++ b/tests/util/test_check_dependencies.py
@@ -12,7 +12,7 @@ from tests.unittest import TestCase
 
 
 class DummyDistribution(metadata.Distribution):
-    def __init__(self, version: str):
+    def __init__(self, version: object):
         self._version = version
 
     @property
@@ -27,7 +27,10 @@ class DummyDistribution(metadata.Distribution):
 
 
 old = DummyDistribution("0.1.2")
+old_release_candidate = DummyDistribution("0.1.2rc3")
 new = DummyDistribution("1.2.3")
+new_release_candidate = DummyDistribution("1.2.3rc4")
+distribution_with_no_version = DummyDistribution(None)
 
 # could probably use stdlib TestCase --- no need for twisted here
 
@@ -65,6 +68,35 @@ class TestDependencyChecker(TestCase):
                 # should not raise
                 check_requirements()
 
+    def test_version_reported_as_none(self) -> None:
+        """Complain if importlib.metadata.version() returns None.
+
+        This shouldn't normally happen, but it was seen in the wild (#12223).
+        """
+        with patch(
+            "synapse.util.check_dependencies.metadata.requires",
+            return_value=["dummypkg >= 1"],
+        ):
+            with self.mock_installed_package(distribution_with_no_version):
+                self.assertRaises(DependencyException, check_requirements)
+
+    def test_checks_ignore_dev_dependencies(self) -> None:
+        """Bot generic and per-extra checks should ignore dev dependencies."""
+        with patch(
+            "synapse.util.check_dependencies.metadata.requires",
+            return_value=["dummypkg >= 1; extra == 'mypy'"],
+        ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
+            # We're testing that none of these calls raise.
+            with self.mock_installed_package(None):
+                check_requirements()
+                check_requirements("cool-extra")
+            with self.mock_installed_package(old):
+                check_requirements()
+                check_requirements("cool-extra")
+            with self.mock_installed_package(new):
+                check_requirements()
+                check_requirements("cool-extra")
+
     def test_generic_check_of_optional_dependency(self) -> None:
         """Complain if an optional package is old."""
         with patch(
@@ -85,11 +117,28 @@ class TestDependencyChecker(TestCase):
         with patch(
             "synapse.util.check_dependencies.metadata.requires",
             return_value=["dummypkg >= 1; extra == 'cool-extra'"],
-        ), patch("synapse.util.check_dependencies.EXTRAS", {"cool-extra"}):
+        ), patch("synapse.util.check_dependencies.RUNTIME_EXTRAS", {"cool-extra"}):
             with self.mock_installed_package(None):
                 self.assertRaises(DependencyException, check_requirements, "cool-extra")
             with self.mock_installed_package(old):
                 self.assertRaises(DependencyException, check_requirements, "cool-extra")
             with self.mock_installed_package(new):
                 # should not raise
+                check_requirements("cool-extra")
+
+    def test_release_candidates_satisfy_dependency(self) -> None:
+        """
+        Tests that release candidates count as far as satisfying a dependency
+        is concerned.
+        (Regression test, see #12176.)
+        """
+        with patch(
+            "synapse.util.check_dependencies.metadata.requires",
+            return_value=["dummypkg >= 1"],
+        ):
+            with self.mock_installed_package(old_release_candidate):
+                self.assertRaises(DependencyException, check_requirements)
+
+            with self.mock_installed_package(new_release_candidate):
+                # should not raise
                 check_requirements()
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index 0774625b85..0c84226197 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import AsyncContextManager, Callable, Sequence, Tuple
+
 from twisted.internet import defer
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
 
 from synapse.util.async_helpers import ReadWriteLock
 
@@ -21,87 +23,187 @@ from tests import unittest
 
 
 class ReadWriteLockTestCase(unittest.TestCase):
-    def _assert_called_before_not_after(self, lst, first_false):
-        for i, d in enumerate(lst[:first_false]):
-            self.assertTrue(d.called, msg="%d was unexpectedly false" % i)
+    def _start_reader_or_writer(
+        self,
+        read_or_write: Callable[[str], AsyncContextManager],
+        key: str,
+        return_value: str,
+    ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
+        """Starts a reader or writer which acquires the lock, blocks, then completes.
+
+        Args:
+            read_or_write: A function returning a context manager for a lock.
+                Either a bound `ReadWriteLock.read` or `ReadWriteLock.write`.
+            key: The key to read or write.
+            return_value: A string that the reader or writer will resolve with when
+                done.
+
+        Returns:
+            A tuple of three `Deferred`s:
+             * A `Deferred` that resolves with `return_value` once the reader or writer
+               completes successfully.
+             * A `Deferred` that resolves once the reader or writer acquires the lock.
+             * A `Deferred` that blocks the reader or writer. Must be resolved by the
+               caller to allow the reader or writer to release the lock and complete.
+        """
+        acquired_d: "Deferred[None]" = Deferred()
+        unblock_d: "Deferred[None]" = Deferred()
+
+        async def reader_or_writer():
+            async with read_or_write(key):
+                acquired_d.callback(None)
+                await unblock_d
+            return return_value
+
+        d = defer.ensureDeferred(reader_or_writer())
+        return d, acquired_d, unblock_d
+
+    def _start_blocking_reader(
+        self, rwlock: ReadWriteLock, key: str, return_value: str
+    ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
+        """Starts a reader which acquires the lock, blocks, then releases the lock.
+
+        See the docstring for `_start_reader_or_writer` for details about the arguments
+        and return values.
+        """
+        return self._start_reader_or_writer(rwlock.read, key, return_value)
+
+    def _start_blocking_writer(
+        self, rwlock: ReadWriteLock, key: str, return_value: str
+    ) -> Tuple["Deferred[str]", "Deferred[None]", "Deferred[None]"]:
+        """Starts a writer which acquires the lock, blocks, then releases the lock.
+
+        See the docstring for `_start_reader_or_writer` for details about the arguments
+        and return values.
+        """
+        return self._start_reader_or_writer(rwlock.write, key, return_value)
+
+    def _start_nonblocking_reader(
+        self, rwlock: ReadWriteLock, key: str, return_value: str
+    ) -> Tuple["Deferred[str]", "Deferred[None]"]:
+        """Starts a reader which acquires the lock, then releases it immediately.
+
+        See the docstring for `_start_reader_or_writer` for details about the arguments.
+
+        Returns:
+            A tuple of two `Deferred`s:
+             * A `Deferred` that resolves with `return_value` once the reader completes
+               successfully.
+             * A `Deferred` that resolves once the reader acquires the lock.
+        """
+        d, acquired_d, unblock_d = self._start_reader_or_writer(
+            rwlock.read, key, return_value
+        )
+        unblock_d.callback(None)
+        return d, acquired_d
+
+    def _start_nonblocking_writer(
+        self, rwlock: ReadWriteLock, key: str, return_value: str
+    ) -> Tuple["Deferred[str]", "Deferred[None]"]:
+        """Starts a writer which acquires the lock, then releases it immediately.
+
+        See the docstring for `_start_reader_or_writer` for details about the arguments.
+
+        Returns:
+            A tuple of two `Deferred`s:
+             * A `Deferred` that resolves with `return_value` once the writer completes
+               successfully.
+             * A `Deferred` that resolves once the writer acquires the lock.
+        """
+        d, acquired_d, unblock_d = self._start_reader_or_writer(
+            rwlock.write, key, return_value
+        )
+        unblock_d.callback(None)
+        return d, acquired_d
+
+    def _assert_first_n_resolved(
+        self, deferreds: Sequence["defer.Deferred[None]"], n: int
+    ) -> None:
+        """Assert that exactly the first n `Deferred`s in the given list are resolved.
 
-        for i, d in enumerate(lst[first_false:]):
+        Args:
+            deferreds: The list of `Deferred`s to be checked.
+            n: The number of `Deferred`s at the start of `deferreds` that should be
+                resolved.
+        """
+        for i, d in enumerate(deferreds[:n]):
+            self.assertTrue(d.called, msg="deferred %d was unexpectedly unresolved" % i)
+
+        for i, d in enumerate(deferreds[n:]):
             self.assertFalse(
-                d.called, msg="%d was unexpectedly true" % (i + first_false)
+                d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
             )
 
     def test_rwlock(self):
         rwlock = ReadWriteLock()
-
-        key = object()
+        key = "key"
 
         ds = [
-            rwlock.read(key),  # 0
-            rwlock.read(key),  # 1
-            rwlock.write(key),  # 2
-            rwlock.write(key),  # 3
-            rwlock.read(key),  # 4
-            rwlock.read(key),  # 5
-            rwlock.write(key),  # 6
+            self._start_blocking_reader(rwlock, key, "0"),
+            self._start_blocking_reader(rwlock, key, "1"),
+            self._start_blocking_writer(rwlock, key, "2"),
+            self._start_blocking_writer(rwlock, key, "3"),
+            self._start_blocking_reader(rwlock, key, "4"),
+            self._start_blocking_reader(rwlock, key, "5"),
+            self._start_blocking_writer(rwlock, key, "6"),
         ]
-        ds = [defer.ensureDeferred(d) for d in ds]
+        # `Deferred`s that resolve when each reader or writer acquires the lock.
+        acquired_ds = [acquired_d for _, acquired_d, _ in ds]
+        # `Deferred`s that will trigger the release of locks when resolved.
+        release_ds = [release_d for _, _, release_d in ds]
 
-        self._assert_called_before_not_after(ds, 2)
+        # The first two readers should acquire their locks.
+        self._assert_first_n_resolved(acquired_ds, 2)
 
-        with ds[0].result:
-            self._assert_called_before_not_after(ds, 2)
-        self._assert_called_before_not_after(ds, 2)
+        # Release one of the read locks. The next writer should not acquire the lock,
+        # because there is another reader holding the lock.
+        self._assert_first_n_resolved(acquired_ds, 2)
+        release_ds[0].callback(None)
+        self._assert_first_n_resolved(acquired_ds, 2)
 
-        with ds[1].result:
-            self._assert_called_before_not_after(ds, 2)
-        self._assert_called_before_not_after(ds, 3)
+        # Release the other read lock. The next writer should acquire the lock.
+        self._assert_first_n_resolved(acquired_ds, 2)
+        release_ds[1].callback(None)
+        self._assert_first_n_resolved(acquired_ds, 3)
 
-        with ds[2].result:
-            self._assert_called_before_not_after(ds, 3)
-        self._assert_called_before_not_after(ds, 4)
+        # Release the write lock. The next writer should acquire the lock.
+        self._assert_first_n_resolved(acquired_ds, 3)
+        release_ds[2].callback(None)
+        self._assert_first_n_resolved(acquired_ds, 4)
 
-        with ds[3].result:
-            self._assert_called_before_not_after(ds, 4)
-        self._assert_called_before_not_after(ds, 6)
+        # Release the write lock. The next two readers should acquire locks.
+        self._assert_first_n_resolved(acquired_ds, 4)
+        release_ds[3].callback(None)
+        self._assert_first_n_resolved(acquired_ds, 6)
 
-        with ds[5].result:
-            self._assert_called_before_not_after(ds, 6)
-        self._assert_called_before_not_after(ds, 6)
+        # Release one of the read locks. The next writer should not acquire the lock,
+        # because there is another reader holding the lock.
+        self._assert_first_n_resolved(acquired_ds, 6)
+        release_ds[5].callback(None)
+        self._assert_first_n_resolved(acquired_ds, 6)
 
-        with ds[4].result:
-            self._assert_called_before_not_after(ds, 6)
-        self._assert_called_before_not_after(ds, 7)
+        # Release the other read lock. The next writer should acquire the lock.
+        self._assert_first_n_resolved(acquired_ds, 6)
+        release_ds[4].callback(None)
+        self._assert_first_n_resolved(acquired_ds, 7)
 
-        with ds[6].result:
-            pass
+        # Release the write lock.
+        release_ds[6].callback(None)
 
-        d = defer.ensureDeferred(rwlock.write(key))
-        self.assertTrue(d.called)
-        with d.result:
-            pass
+        # Acquire and release the write and read locks one last time for good measure.
+        _, acquired_d = self._start_nonblocking_writer(rwlock, key, "last writer")
+        self.assertTrue(acquired_d.called)
 
-        d = defer.ensureDeferred(rwlock.read(key))
-        self.assertTrue(d.called)
-        with d.result:
-            pass
+        _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
+        self.assertTrue(acquired_d.called)
 
     def test_lock_handoff_to_nonblocking_writer(self):
         """Test a writer handing the lock to another writer that completes instantly."""
         rwlock = ReadWriteLock()
         key = "key"
 
-        unblock: "Deferred[None]" = Deferred()
-
-        async def blocking_write():
-            with await rwlock.write(key):
-                await unblock
-
-        async def nonblocking_write():
-            with await rwlock.write(key):
-                pass
-
-        d1 = defer.ensureDeferred(blocking_write())
-        d2 = defer.ensureDeferred(nonblocking_write())
+        d1, _, unblock = self._start_blocking_writer(rwlock, key, "write 1 completed")
+        d2, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
         self.assertFalse(d1.called)
         self.assertFalse(d2.called)
 
@@ -111,5 +213,182 @@ class ReadWriteLockTestCase(unittest.TestCase):
         self.assertTrue(d2.called)
 
         # The `ReadWriteLock` should operate as normal.
-        d3 = defer.ensureDeferred(nonblocking_write())
+        d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
         self.assertTrue(d3.called)
+
+    def test_cancellation_while_holding_read_lock(self):
+        """Test cancellation while holding a read lock.
+
+        A waiting writer should be given the lock when the reader holding the lock is
+        cancelled.
+        """
+        rwlock = ReadWriteLock()
+        key = "key"
+
+        # 1. A reader takes the lock and blocks.
+        reader_d, _, _ = self._start_blocking_reader(rwlock, key, "read completed")
+
+        # 2. A writer waits for the reader to complete.
+        writer_d, _ = self._start_nonblocking_writer(rwlock, key, "write completed")
+        self.assertFalse(writer_d.called)
+
+        # 3. The reader is cancelled.
+        reader_d.cancel()
+        self.failureResultOf(reader_d, CancelledError)
+
+        # 4. The writer should take the lock and complete.
+        self.assertTrue(
+            writer_d.called, "Writer is stuck waiting for a cancelled reader"
+        )
+        self.assertEqual("write completed", self.successResultOf(writer_d))
+
+    def test_cancellation_while_holding_write_lock(self):
+        """Test cancellation while holding a write lock.
+
+        A waiting reader should be given the lock when the writer holding the lock is
+        cancelled.
+        """
+        rwlock = ReadWriteLock()
+        key = "key"
+
+        # 1. A writer takes the lock and blocks.
+        writer_d, _, _ = self._start_blocking_writer(rwlock, key, "write completed")
+
+        # 2. A reader waits for the writer to complete.
+        reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed")
+        self.assertFalse(reader_d.called)
+
+        # 3. The writer is cancelled.
+        writer_d.cancel()
+        self.failureResultOf(writer_d, CancelledError)
+
+        # 4. The reader should take the lock and complete.
+        self.assertTrue(
+            reader_d.called, "Reader is stuck waiting for a cancelled writer"
+        )
+        self.assertEqual("read completed", self.successResultOf(reader_d))
+
+    def test_cancellation_while_waiting_for_read_lock(self):
+        """Test cancellation while waiting for a read lock.
+
+        Tests that cancelling a waiting reader:
+         * does not cancel the writer it is waiting on
+         * does not cancel the next writer waiting on it
+         * does not allow the next writer to acquire the lock before an earlier writer
+           has finished
+         * does not keep the next writer waiting indefinitely
+
+        These correspond to the asserts with explicit messages.
+        """
+        rwlock = ReadWriteLock()
+        key = "key"
+
+        # 1. A writer takes the lock and blocks.
+        writer1_d, _, unblock_writer1 = self._start_blocking_writer(
+            rwlock, key, "write 1 completed"
+        )
+
+        # 2. A reader waits for the first writer to complete.
+        #    This reader will be cancelled later.
+        reader_d, _ = self._start_nonblocking_reader(rwlock, key, "read completed")
+        self.assertFalse(reader_d.called)
+
+        # 3. A second writer waits for both the first writer and the reader to complete.
+        writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
+        self.assertFalse(writer2_d.called)
+
+        # 4. The waiting reader is cancelled.
+        #    Neither of the writers should be cancelled.
+        #    The second writer should still be waiting, but only on the first writer.
+        reader_d.cancel()
+        self.failureResultOf(reader_d, CancelledError)
+        self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
+        self.assertFalse(
+            writer2_d.called,
+            "Second writer was unexpectedly cancelled or given the lock before the "
+            "first writer finished",
+        )
+
+        # 5. Unblock the first writer, which should complete.
+        unblock_writer1.callback(None)
+        self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
+
+        # 6. The second writer should take the lock and complete.
+        self.assertTrue(
+            writer2_d.called, "Second writer is stuck waiting for a cancelled reader"
+        )
+        self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
+
+    def test_cancellation_while_waiting_for_write_lock(self):
+        """Test cancellation while waiting for a write lock.
+
+        Tests that cancelling a waiting writer:
+         * does not cancel the reader or writer it is waiting on
+         * does not cancel the next writer waiting on it
+         * does not allow the next writer to acquire the lock before an earlier reader
+           and writer have finished
+         * does not keep the next writer waiting indefinitely
+
+        These correspond to the asserts with explicit messages.
+        """
+        rwlock = ReadWriteLock()
+        key = "key"
+
+        # 1. A reader takes the lock and blocks.
+        reader_d, _, unblock_reader = self._start_blocking_reader(
+            rwlock, key, "read completed"
+        )
+
+        # 2. A writer waits for the reader to complete.
+        writer1_d, _, unblock_writer1 = self._start_blocking_writer(
+            rwlock, key, "write 1 completed"
+        )
+
+        # 3. A second writer waits for both the reader and first writer to complete.
+        #    This writer will be cancelled later.
+        writer2_d, _ = self._start_nonblocking_writer(rwlock, key, "write 2 completed")
+        self.assertFalse(writer2_d.called)
+
+        # 4. A third writer waits for the second writer to complete.
+        writer3_d, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
+        self.assertFalse(writer3_d.called)
+
+        # 5. The second writer is cancelled, but continues waiting for the lock.
+        #    The reader, first writer and third writer should not be cancelled.
+        #    The first writer should still be waiting on the reader.
+        #    The third writer should still be waiting on the second writer.
+        writer2_d.cancel()
+        self.assertNoResult(writer2_d)
+        self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled")
+        self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
+        self.assertFalse(
+            writer3_d.called,
+            "Third writer was unexpectedly cancelled or given the lock before the first "
+            "writer finished",
+        )
+
+        # 6. Unblock the reader, which should complete.
+        #    The first writer should be given the lock and block.
+        #    The third writer should still be waiting on the second writer.
+        unblock_reader.callback(None)
+        self.assertEqual("read completed", self.successResultOf(reader_d))
+        self.assertNoResult(writer2_d)
+        self.assertFalse(
+            writer3_d.called,
+            "Third writer was unexpectedly given the lock before the first writer "
+            "finished",
+        )
+
+        # 7. Unblock the first writer, which should complete.
+        unblock_writer1.callback(None)
+        self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
+
+        # 8. The second writer should take the lock and release it immediately, since it
+        #    has been cancelled.
+        self.failureResultOf(writer2_d, CancelledError)
+
+        # 9. The third writer should take the lock and complete.
+        self.assertTrue(
+            writer3_d.called, "Third writer is stuck waiting for a cancelled writer"
+        )
+        self.assertEqual("write 3 completed", self.successResultOf(writer3_d))
diff --git a/tests/utils.py b/tests/utils.py
index ef99c72e0b..f6b1d60371 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,13 +15,8 @@
 
 import atexit
 import os
-from unittest.mock import Mock, patch
-from urllib import parse as urlparse
-
-from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
@@ -187,111 +182,6 @@ def mock_getRawHeaders(headers=None):
     return getRawHeaders
 
 
-# This is a mock /resource/ not an entire server
-class MockHttpResource:
-    def __init__(self, prefix=""):
-        self.callbacks = []  # 3-tuple of method/pattern/function
-        self.prefix = prefix
-
-    def trigger_get(self, path):
-        return self.trigger(b"GET", path, None)
-
-    @patch("twisted.web.http.Request")
-    @defer.inlineCallbacks
-    def trigger(
-        self, http_method, path, content, mock_request, federation_auth_origin=None
-    ):
-        """Fire an HTTP event.
-
-        Args:
-            http_method : The HTTP method
-            path : The HTTP path
-            content : The HTTP body
-            mock_request : Mocked request to pass to the event so it can get
-                           content.
-            federation_auth_origin (bytes|None): domain to authenticate as, for federation
-        Returns:
-            A tuple of (code, response)
-        Raises:
-            KeyError If no event is found which will handle the path.
-        """
-        path = self.prefix + path
-
-        # annoyingly we return a twisted http request which has chained calls
-        # to get at the http content, hence mock it here.
-        mock_content = Mock()
-        config = {"read.return_value": content}
-        mock_content.configure_mock(**config)
-        mock_request.content = mock_content
-
-        mock_request.method = http_method.encode("ascii")
-        mock_request.uri = path.encode("ascii")
-
-        mock_request.getClientIP.return_value = "-"
-
-        headers = {}
-        if federation_auth_origin is not None:
-            headers[b"Authorization"] = [
-                b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
-            ]
-        mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
-
-        # return the right path if the event requires it
-        mock_request.path = path
-
-        # add in query params to the right place
-        try:
-            mock_request.args = urlparse.parse_qs(path.split("?")[1])
-            mock_request.path = path.split("?")[0]
-            path = mock_request.path
-        except Exception:
-            pass
-
-        if isinstance(path, bytes):
-            path = path.decode("utf8")
-
-        for (method, pattern, func) in self.callbacks:
-            if http_method != method:
-                continue
-
-            matcher = pattern.match(path)
-            if matcher:
-                try:
-                    args = [urlparse.unquote(u) for u in matcher.groups()]
-
-                    (code, response) = yield defer.ensureDeferred(
-                        func(mock_request, *args)
-                    )
-                    return code, response
-                except CodeMessageException as e:
-                    return e.code, cs_error(e.msg, code=e.errcode)
-
-        raise KeyError("No event can handle %s" % path)
-
-    def register_paths(self, method, path_patterns, callback, servlet_name):
-        for path_pattern in path_patterns:
-            self.callbacks.append((method, path_pattern, callback))
-
-
-class MockKey:
-    alg = "mock_alg"
-    version = "mock_version"
-    signature = b"\x9a\x87$"
-
-    @property
-    def verify_key(self):
-        return self
-
-    def sign(self, message):
-        return self
-
-    def verify(self, message, sig):
-        assert sig == b"\x9a\x87$"
-
-    def encode(self):
-        return b"<fake_encoded_key>"
-
-
 class MockClock:
     now = 1000