summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/app/test_openid_listener.py8
-rw-r--r--tests/appservice/test_scheduler.py6
-rw-r--r--tests/crypto/test_keyring.py22
-rw-r--r--tests/events/test_presence_router.py22
-rw-r--r--tests/federation/test_complexity.py55
-rw-r--r--tests/federation/test_federation_catch_up.py84
-rw-r--r--tests/federation/test_federation_client.py16
-rw-r--r--tests/federation/test_federation_sender.py127
-rw-r--r--tests/federation/test_federation_server.py19
-rw-r--r--tests/federation/transport/server/test__base.py4
-rw-r--r--tests/federation/transport/test_knocking.py32
-rw-r--r--tests/federation/transport/test_server.py6
-rw-r--r--tests/handlers/test_admin.py56
-rw-r--r--tests/handlers/test_appservice.py2
-rw-r--r--tests/handlers/test_cas.py8
-rw-r--r--tests/handlers/test_e2e_keys.py55
-rw-r--r--tests/handlers/test_federation.py58
-rw-r--r--tests/handlers/test_federation_event.py6
-rw-r--r--tests/handlers/test_message.py36
-rw-r--r--tests/handlers/test_oidc.py4
-rw-r--r--tests/handlers/test_password_providers.py2
-rw-r--r--tests/handlers/test_register.py17
-rw-r--r--tests/handlers/test_saml.py14
-rw-r--r--tests/handlers/test_sso.py1
-rw-r--r--tests/handlers/test_stats.py1
-rw-r--r--tests/handlers/test_typing.py12
-rw-r--r--tests/handlers/test_user_directory.py42
-rw-r--r--tests/http/__init__.py19
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py152
-rw-r--r--tests/http/federation/test_srv_resolver.py61
-rw-r--r--tests/http/server/_base.py2
-rw-r--r--tests/http/test_additional_resource.py18
-rw-r--r--tests/http/test_client.py41
-rw-r--r--tests/http/test_endpoint.py4
-rw-r--r--tests/http/test_matrixfederationclient.py53
-rw-r--r--tests/http/test_proxyagent.py134
-rw-r--r--tests/http/test_servlet.py8
-rw-r--r--tests/http/test_simple_client.py14
-rw-r--r--tests/http/test_site.py8
-rw-r--r--tests/logging/test_remote_handler.py17
-rw-r--r--tests/media/__init__.py (renamed from tests/rest/media/v1/__init__.py)2
-rw-r--r--tests/media/test_base.py (renamed from tests/rest/media/v1/test_base.py)2
-rw-r--r--tests/media/test_filepath.py (renamed from tests/rest/media/v1/test_filepath.py)2
-rw-r--r--tests/media/test_html_preview.py (renamed from tests/rest/media/v1/test_html_preview.py)2
-rw-r--r--tests/media/test_media_storage.py (renamed from tests/rest/media/v1/test_media_storage.py)62
-rw-r--r--tests/media/test_oembed.py (renamed from tests/rest/media/v1/test_oembed.py)2
-rw-r--r--tests/module_api/test_api.py136
-rw-r--r--tests/push/test_bulk_push_rule_evaluator.py14
-rw-r--r--tests/push/test_email.py59
-rw-r--r--tests/push/test_http.py45
-rw-r--r--tests/push/test_push_rule_evaluator.py328
-rw-r--r--tests/replication/_base.py70
-rw-r--r--tests/replication/http/test__base.py2
-rw-r--r--tests/replication/slave/storage/_base.py25
-rw-r--r--tests/replication/slave/storage/test_events.py86
-rw-r--r--tests/replication/tcp/streams/test_account_data.py4
-rw-r--r--tests/replication/tcp/streams/test_events.py26
-rw-r--r--tests/replication/tcp/streams/test_federation.py2
-rw-r--r--tests/replication/tcp/streams/test_partial_state.py4
-rw-r--r--tests/replication/tcp/streams/test_typing.py37
-rw-r--r--tests/replication/tcp/test_commands.py6
-rw-r--r--tests/replication/tcp/test_handler.py62
-rw-r--r--tests/replication/tcp/test_remote_server_up.py8
-rw-r--r--tests/replication/test_auth.py14
-rw-r--r--tests/replication/test_client_reader_shard.py4
-rw-r--r--tests/replication/test_federation_ack.py12
-rw-r--r--tests/replication/test_federation_sender_shard.py12
-rw-r--r--tests/replication/test_module_cache_invalidation.py2
-rw-r--r--tests/replication/test_multi_media_repo.py16
-rw-r--r--tests/replication/test_pusher_shard.py12
-rw-r--r--tests/replication/test_sharded_event_persister.py14
-rw-r--r--tests/rest/admin/test_device.py3
-rw-r--r--tests/rest/admin/test_event_reports.py143
-rw-r--r--tests/rest/admin/test_media.py16
-rw-r--r--tests/rest/admin/test_room.py1
-rw-r--r--tests/rest/admin/test_server_notice.py5
-rw-r--r--tests/rest/admin/test_user.py11
-rw-r--r--tests/rest/admin/test_username_available.py15
-rw-r--r--tests/rest/client/test_account.py6
-rw-r--r--tests/rest/client/test_auth.py19
-rw-r--r--tests/rest/client/test_capabilities.py1
-rw-r--r--tests/rest/client/test_consent.py1
-rw-r--r--tests/rest/client/test_directory.py1
-rw-r--r--tests/rest/client/test_ephemeral_message.py1
-rw-r--r--tests/rest/client/test_events.py3
-rw-r--r--tests/rest/client/test_filter.py5
-rw-r--r--tests/rest/client/test_keys.py141
-rw-r--r--tests/rest/client/test_login.py2
-rw-r--r--tests/rest/client/test_login_token_request.py1
-rw-r--r--tests/rest/client/test_presence.py11
-rw-r--r--tests/rest/client/test_profile.py3
-rw-r--r--tests/rest/client/test_register.py11
-rw-r--r--tests/rest/client/test_relations.py237
-rw-r--r--tests/rest/client/test_rendezvous.py1
-rw-r--r--tests/rest/client/test_report_event.py12
-rw-r--r--tests/rest/client/test_retention.py6
-rw-r--r--tests/rest/client/test_rooms.py30
-rw-r--r--tests/rest/client/test_shadow_banned.py6
-rw-r--r--tests/rest/client/test_sync.py3
-rw-r--r--tests/rest/client/test_third_party_rules.py126
-rw-r--r--tests/rest/client/test_transactions.py55
-rw-r--r--tests/rest/client/test_upgrade_room.py2
-rw-r--r--tests/rest/client/utils.py58
-rw-r--r--tests/rest/media/test_media_retention.py1
-rw-r--r--tests/rest/media/test_url_preview.py (renamed from tests/rest/media/v1/test_url_preview.py)53
-rw-r--r--tests/scripts/test_new_matrix_user.py25
-rw-r--r--tests/server.py259
-rw-r--r--tests/server_notices/test_consent.py16
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py85
-rw-r--r--tests/storage/databases/main/test_deviceinbox.py1
-rw-r--r--tests/storage/databases/main/test_events_worker.py1
-rw-r--r--tests/storage/databases/main/test_receipts.py2
-rw-r--r--tests/storage/databases/main/test_room.py1
-rw-r--r--tests/storage/test_account_data.py22
-rw-r--r--tests/storage/test_cleanup_extrems.py8
-rw-r--r--tests/storage/test_client_ips.py1
-rw-r--r--tests/storage/test_event_chain.py18
-rw-r--r--tests/storage/test_event_federation.py11
-rw-r--r--tests/storage/test_event_metrics.py3
-rw-r--r--tests/storage/test_event_push_actions.py2
-rw-r--r--tests/storage/test_events.py8
-rw-r--r--tests/storage/test_keys.py61
-rw-r--r--tests/storage/test_purge.py3
-rw-r--r--tests/storage/test_receipts.py16
-rw-r--r--tests/storage/test_redaction.py24
-rw-r--r--tests/storage/test_room_search.py3
-rw-r--r--tests/storage/test_roommember.py3
-rw-r--r--tests/storage/test_state.py160
-rw-r--r--tests/storage/test_stream.py4
-rw-r--r--tests/storage/test_unsafe_locale.py2
-rw-r--r--tests/storage/test_user_directory.py198
-rw-r--r--tests/test_distributor.py12
-rw-r--r--tests/test_event_auth.py32
-rw-r--r--tests/test_federation.py111
-rw-r--r--tests/test_mau.py36
-rw-r--r--tests/test_phone_home.py2
-rw-r--r--tests/test_rust.py2
-rw-r--r--tests/test_state.py198
-rw-r--r--tests/test_terms_auth.py19
-rw-r--r--tests/test_test_utils.py16
-rw-r--r--tests/test_types.py30
-rw-r--r--tests/test_utils/__init__.py26
-rw-r--r--tests/test_utils/event_injection.py15
-rw-r--r--tests/test_utils/html_parsers.py6
-rw-r--r--tests/test_utils/logging_setup.py4
-rw-r--r--tests/test_utils/oidc.py10
-rw-r--r--tests/test_visibility.py27
-rw-r--r--tests/unittest.py21
-rw-r--r--tests/util/test_retryutils.py2
-rw-r--r--tests/utils.py31
150 files changed, 3197 insertions, 1648 deletions
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 5d89ba94ad..2ee343d8a4 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -67,7 +67,9 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
         }
 
         # Listen with the config
-        self.hs._listen_http(parse_listener_def(0, config))
+        hs = self.hs
+        assert isinstance(hs, GenericWorkerServer)
+        hs._listen_http(parse_listener_def(0, config))
 
         # Grab the resource from the site that was told to listen
         site = self.reactor.tcpServers[0][1]
@@ -115,7 +117,9 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
         }
 
         # Listen with the config
-        self.hs._listener_http(self.hs.config, parse_listener_def(0, config))
+        hs = self.hs
+        assert isinstance(hs, SynapseHomeServer)
+        hs._listener_http(self.hs.config, parse_listener_def(0, config))
 
         # Grab the resource from the site that was told to listen
         site = self.reactor.tcpServers[0][1]
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index febcc1499d..e2a3bad065 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -11,12 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
+from typing import List, Optional, Sequence, Tuple, cast
 from unittest.mock import Mock
 
 from typing_extensions import TypeAlias
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.appservice import (
     ApplicationService,
@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
 
 from ..utils import MockClock
 
-if TYPE_CHECKING:
-    from twisted.internet.testing import MemoryReactor
-
 
 class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
     def setUp(self) -> None:
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 56193bc000..d6db5b6423 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -194,7 +194,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         key1 = signedjson.key.generate_signing_key("1")
         r = self.hs.get_datastores().main.store_server_verify_keys(
             "server9",
-            time.time() * 1000,
+            int(time.time() * 1000),
             [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
         )
         self.get_success(r)
@@ -289,7 +289,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         key1 = signedjson.key.generate_signing_key("1")
         r = self.hs.get_datastores().main.store_server_verify_keys(
             "server9",
-            time.time() * 1000,
+            int(time.time() * 1000),
             # None is not a valid value in FetchKeyResult, but we're abusing this
             # API to insert null values into the database. The nulls get converted
             # to 0 when fetched in KeyStore.get_server_verify_keys.
@@ -468,9 +468,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
         )
-        res = key_json[lookup_triplet]
-        self.assertEqual(len(res), 1)
-        res = res[0]
+        res_keys = key_json[lookup_triplet]
+        self.assertEqual(len(res_keys), 1)
+        res = res_keys[0]
         self.assertEqual(res["key_id"], testverifykey_id)
         self.assertEqual(res["from_server"], SERVER_NAME)
         self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -586,9 +586,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
         )
-        res = key_json[lookup_triplet]
-        self.assertEqual(len(res), 1)
-        res = res[0]
+        res_keys = key_json[lookup_triplet]
+        self.assertEqual(len(res_keys), 1)
+        res = res_keys[0]
         self.assertEqual(res["key_id"], testverifykey_id)
         self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
         self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -707,9 +707,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
         )
-        res = key_json[lookup_triplet]
-        self.assertEqual(len(res), 1)
-        res = res[0]
+        res_keys = key_json[lookup_triplet]
+        self.assertEqual(len(res_keys), 1)
+        res = res_keys[0]
         self.assertEqual(res["key_id"], testverifykey_id)
         self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
         self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index a9893def74..6fb1f1bd6e 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -31,7 +31,11 @@ from synapse.util import Clock
 
 from tests.handlers.test_sync import generate_sync_config
 from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import (
+    FederatingHomeserverTestCase,
+    HomeserverTestCase,
+    override_config,
+)
 
 
 @attr.s
@@ -152,11 +156,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
-        fed_transport_client = Mock(spec=["send_transaction"])
-        fed_transport_client.send_transaction = simple_async_mock({})
+        self.fed_transport_client = Mock(spec=["send_transaction"])
+        self.fed_transport_client.send_transaction = simple_async_mock({})
 
         hs = self.setup_test_homeserver(
-            federation_transport_client=fed_transport_client,
+            federation_transport_client=self.fed_transport_client,
         )
 
         load_legacy_presence_router(hs)
@@ -418,7 +422,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         #
         # Thus we reset the mock, and try sending all online local user
         # presence again
-        self.hs.get_federation_transport_client().send_transaction.reset_mock()
+        self.fed_transport_client.send_transaction.reset_mock()
 
         # Broadcast local user online presence
         self.get_success(
@@ -443,9 +447,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         }
         found_users = set()
 
-        calls = (
-            self.hs.get_federation_transport_client().send_transaction.call_args_list
-        )
+        calls = self.fed_transport_client.send_transaction.call_args_list
         for call in calls:
             call_args = call[0]
             federation_transaction: Transaction = call_args[0]
@@ -470,7 +472,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
 
 
 def send_presence_update(
-    testcase: FederatingHomeserverTestCase,
+    testcase: HomeserverTestCase,
     user_id: str,
     access_token: str,
     presence_state: str,
@@ -491,7 +493,7 @@ def send_presence_update(
 
 
 def sync_presence(
-    testcase: FederatingHomeserverTestCase,
+    testcase: HomeserverTestCase,
     user_id: str,
     since_token: Optional[StreamToken] = None,
 ) -> Tuple[List[UserPresenceState], StreamToken]:
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 9f1115dd23..33af8770fd 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -17,27 +17,25 @@ from unittest.mock import Mock
 from synapse.api.errors import Codes, SynapseError
 from synapse.rest import admin
 from synapse.rest.client import login, room
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID, create_requester
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 
 
 class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
-
     servlets = [
         admin.register_servlets,
         room.register_servlets,
         login.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = super().default_config()
         config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
         return config
 
-    def test_complexity_simple(self):
-
+    def test_complexity_simple(self) -> None:
         u1 = self.register_user("u1", "pass")
         u1_token = self.login("u1", "pass")
 
@@ -56,7 +54,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
 
         # Artificially raise the complexity
         store = self.hs.get_datastores().main
-        store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
+
+        async def get_current_state_event_counts(room_id: str) -> int:
+            return int(500 * 1.23)
+
+        store.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[assignment]
 
         # Get the room complexity again -- make sure it's our artificial value
         channel = self.make_signed_federation_request(
@@ -66,8 +68,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         complexity = channel.json_body["v1"]
         self.assertEqual(complexity, 1.23)
 
-    def test_join_too_large(self):
-
+    def test_join_too_large(self) -> None:
         u1 = self.register_user("u1", "pass")
 
         handler = self.hs.get_room_member_handler()
@@ -75,12 +76,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
 
         # Mock out some things, because we don't want to test the whole join
         fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
-        handler.federation_handler.do_invite_join = Mock(
+        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
-            None,
+            create_requester(u1),
             ["other.example.com"],
             "roomid",
             UserID.from_string(u1),
@@ -95,7 +96,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         self.assertEqual(f.value.code, 400, f.value)
         self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
-    def test_join_too_large_admin(self):
+    def test_join_too_large_admin(self) -> None:
         # Check whether an admin can join if option "admins_can_join" is undefined,
         # this option defaults to false, so the join should fail.
 
@@ -106,12 +107,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
 
         # Mock out some things, because we don't want to test the whole join
         fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
-        handler.federation_handler.do_invite_join = Mock(
+        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
-            None,
+            create_requester(u1),
             ["other.example.com"],
             "roomid",
             UserID.from_string(u1),
@@ -126,8 +127,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         self.assertEqual(f.value.code, 400, f.value)
         self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
-    def test_join_too_large_once_joined(self):
-
+    def test_join_too_large_once_joined(self) -> None:
         u1 = self.register_user("u1", "pass")
         u1_token = self.login("u1", "pass")
 
@@ -144,17 +144,18 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
 
         # Mock out some things, because we don't want to test the whole join
         fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
-        handler.federation_handler.do_invite_join = Mock(
+        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
 
         # Artificially raise the complexity
-        self.hs.get_datastores().main.get_current_state_event_counts = (
-            lambda x: make_awaitable(600)
-        )
+        async def get_current_state_event_counts(room_id: str) -> int:
+            return 600
+
+        self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts  # type: ignore[assignment]
 
         d = handler._remote_join(
-            None,
+            create_requester(u1),
             ["other.example.com"],
             room_1,
             UserID.from_string(u1),
@@ -180,7 +181,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         login.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = super().default_config()
         config["limit_remote_rooms"] = {
             "enabled": True,
@@ -189,7 +190,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         }
         return config
 
-    def test_join_too_large_no_admin(self):
+    def test_join_too_large_no_admin(self) -> None:
         # A user which is not an admin should not be able to join a remote room
         # which is too complex.
 
@@ -200,12 +201,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
 
         # Mock out some things, because we don't want to test the whole join
         fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
-        handler.federation_handler.do_invite_join = Mock(
+        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
-            None,
+            create_requester(u1),
             ["other.example.com"],
             "roomid",
             UserID.from_string(u1),
@@ -220,7 +221,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         self.assertEqual(f.value.code, 400, f.value)
         self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
-    def test_join_too_large_admin(self):
+    def test_join_too_large_admin(self) -> None:
         # An admin should be able to join rooms where a complexity check fails.
 
         u1 = self.register_user("u1", "pass", admin=True)
@@ -230,12 +231,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
 
         # Mock out some things, because we don't want to test the whole join
         fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
-        handler.federation_handler.do_invite_join = Mock(
+        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
-            None,
+            create_requester(u1),
             ["other.example.com"],
             "roomid",
             UserID.from_string(u1),
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index b8fee72898..6381583c24 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,13 +1,21 @@
-from typing import List, Tuple
+from typing import Callable, List, Optional, Tuple
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.federation.sender import PerDestinationQueue, TransactionManager
-from synapse.federation.units import Edu
+from synapse.federation.sender import (
+    FederationSender,
+    PerDestinationQueue,
+    TransactionManager,
+)
+from synapse.federation.units import Edu, Transaction
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
 from synapse.types import JsonDict
+from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
 from tests.test_utils import event_injection, make_awaitable
@@ -28,35 +36,47 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.federation_transport_client = Mock(spec=["send_transaction"])
         return self.setup_test_homeserver(
-            federation_transport_client=Mock(spec=["send_transaction"]),
+            federation_transport_client=self.federation_transport_client,
         )
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # stub out get_current_hosts_in_room
-        state_handler = hs.get_state_handler()
+        state_storage_controller = hs.get_storage_controllers().state
 
         # This mock is crucial for destination_rooms to be populated.
-        state_handler.get_current_hosts_in_room = Mock(
-            return_value=make_awaitable(["test", "host2"])
+        # TODO: this seems to no longer be the case---tests pass with this mock
+        # commented out.
+        state_storage_controller.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
+            return_value=make_awaitable({"test", "host2"})
         )
 
         # whenever send_transaction is called, record the pdu data
-        self.pdus = []
-        self.failed_pdus = []
+        self.pdus: List[JsonDict] = []
+        self.failed_pdus: List[JsonDict] = []
         self.is_online = True
-        self.hs.get_federation_transport_client().send_transaction.side_effect = (
+        self.federation_transport_client.send_transaction.side_effect = (
             self.record_transaction
         )
 
+        federation_sender = hs.get_federation_sender()
+        assert isinstance(federation_sender, FederationSender)
+        self.federation_sender = federation_sender
+
     def default_config(self) -> JsonDict:
         config = super().default_config()
         config["federation_sender_instances"] = None
         return config
 
-    async def record_transaction(self, txn, json_cb):
-        if self.is_online:
+    async def record_transaction(
+        self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]]
+    ) -> JsonDict:
+        if json_cb is None:
+            # The tests seem to expect that this method raises in this situation.
+            raise Exception("Blank json_cb")
+        elif self.is_online:
             data = json_cb()
             self.pdus.extend(data["pdus"])
             return {}
@@ -92,7 +112,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         )[0]
         return {"event_id": event_id, "stream_ordering": stream_ordering}
 
-    def test_catch_up_destination_rooms_tracking(self):
+    def test_catch_up_destination_rooms_tracking(self) -> None:
         """
         Tests that we populate the `destination_rooms` table as needed.
         """
@@ -117,7 +137,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         self.assertEqual(row_2["event_id"], event_id_2)
         self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
 
-    def test_catch_up_last_successful_stream_ordering_tracking(self):
+    def test_catch_up_last_successful_stream_ordering_tracking(self) -> None:
         """
         Tests that we populate the `destination_rooms` table as needed.
         """
@@ -174,7 +194,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
             "Send succeeded but not marked as last_successful_stream_ordering",
         )
 
-    def test_catch_up_from_blank_state(self):
+    def test_catch_up_from_blank_state(self) -> None:
         """
         Runs an overall test of federation catch-up from scratch.
         Further tests will focus on more narrow aspects and edge-cases, but I
@@ -218,11 +238,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # let's delete the federation transmission queue
         # (this pretends we are starting up fresh.)
         self.assertFalse(
-            self.hs.get_federation_sender()
-            ._per_destination_queues["host2"]
-            .transmission_loop_running
+            self.federation_sender._per_destination_queues[
+                "host2"
+            ].transmission_loop_running
         )
-        del self.hs.get_federation_sender()._per_destination_queues["host2"]
+        del self.federation_sender._per_destination_queues["host2"]
 
         # let's also clear any backoffs
         self.get_success(
@@ -261,16 +281,15 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
             destination_tm: str,
             pending_pdus: List[EventBase],
             _pending_edus: List[Edu],
-        ) -> bool:
+        ) -> None:
             assert destination == destination_tm
             results_list.extend(pending_pdus)
-            return True  # success!
 
-        transaction_manager.send_new_transaction = fake_send
+        transaction_manager.send_new_transaction = fake_send  # type: ignore[assignment]
 
         return per_dest_queue, results_list
 
-    def test_catch_up_loop(self):
+    def test_catch_up_loop(self) -> None:
         """
         Tests the behaviour of _catch_up_transmission_loop.
         """
@@ -312,6 +331,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # also fetch event 5 so we know its last_successful_stream_ordering later
         event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5))
 
+        assert event_2.internal_metadata.stream_ordering is not None
         self.get_success(
             self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
                 "host2", event_2.internal_metadata.stream_ordering
@@ -334,7 +354,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
             event_5.internal_metadata.stream_ordering,
         )
 
-    def test_catch_up_on_synapse_startup(self):
+    def test_catch_up_on_synapse_startup(self) -> None:
         """
         Tests the behaviour of get_catch_up_outstanding_destinations and
             _wake_destinations_needing_catchup.
@@ -412,18 +432,19 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # patch wake_destination to just count the destinations instead
         woken = []
 
-        def wake_destination_track(destination):
+        def wake_destination_track(destination: str) -> None:
             woken.append(destination)
 
-        self.hs.get_federation_sender().wake_destination = wake_destination_track
+        self.federation_sender.wake_destination = wake_destination_track  # type: ignore[assignment]
 
         # cancel the pre-existing timer for _wake_destinations_needing_catchup
         # this is because we are calling it manually rather than waiting for it
         # to be called automatically
-        self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
+        assert self.federation_sender._catchup_after_startup_timer is not None
+        self.federation_sender._catchup_after_startup_timer.cancel()
 
         self.get_success(
-            self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
+            self.federation_sender._wake_destinations_needing_catchup(), by=5.0
         )
 
         # ASSERT (_wake_destinations_needing_catchup):
@@ -432,7 +453,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # - all destinations are woken exactly once; they appear once in woken.
         self.assertCountEqual(woken, server_names[:-1])
 
-    def test_not_latest_event(self):
+    def test_not_latest_event(self) -> None:
         """Test that we send the latest event in the room even if its not ours."""
 
         per_dest_queue, sent_pdus = self.make_fake_destination_queue()
@@ -465,6 +486,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
             )
         )
 
+        assert event_1.internal_metadata.stream_ordering is not None
         self.get_success(
             self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
                 "host2", event_1.internal_metadata.stream_ordering
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index e67f405826..91694e4fca 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -36,7 +36,9 @@ class FederationClientTest(FederatingHomeserverTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         super().prepare(reactor, clock, homeserver)
 
         # mock out the Agent used by the federation client, which is easier than
@@ -51,7 +53,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
         self.creator = f"@creator:{self.OTHER_SERVER_NAME}"
         self.test_room_id = "!room_id"
 
-    def test_get_room_state(self):
+    def test_get_room_state(self) -> None:
         # mock up some events to use in the response.
         # In real life, these would have things in `prev_events` and `auth_events`, but that's
         # a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
@@ -140,7 +142,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
             ["m.room.create", "m.room.member", "m.room.power_levels"],
         )
 
-    def test_get_pdu_returns_nothing_when_event_does_not_exist(self):
+    def test_get_pdu_returns_nothing_when_event_does_not_exist(self) -> None:
         """No event should be returned when the event does not exist"""
         pulled_pdu_info = self.get_success(
             self.hs.get_federation_client().get_pdu(
@@ -151,11 +153,11 @@ class FederationClientTest(FederatingHomeserverTestCase):
         )
         self.assertEqual(pulled_pdu_info, None)
 
-    def test_get_pdu(self):
+    def test_get_pdu(self) -> None:
         """Test to make sure an event is returned by `get_pdu()`"""
         self._get_pdu_once()
 
-    def test_get_pdu_event_from_cache_is_pristine(self):
+    def test_get_pdu_event_from_cache_is_pristine(self) -> None:
         """Test that modifications made to events returned by `get_pdu()`
         do not propagate back to to the internal cache (events returned should
         be a copy).
@@ -176,7 +178,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
                 RoomVersions.V9,
             )
         )
-        self.assertIsNotNone(pulled_pdu_info2)
+        assert pulled_pdu_info2 is not None
         remote_pdu2 = pulled_pdu_info2.pdu
 
         # Sanity check that we are working against the same event
@@ -224,7 +226,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
                 RoomVersions.V9,
             )
         )
-        self.assertIsNotNone(pulled_pdu_info)
+        assert pulled_pdu_info is not None
         remote_pdu = pulled_pdu_info.pdu
 
         # check the right call got made to the agent
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 8692d8190f..9e104fd96a 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -11,18 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Callable, FrozenSet, List, Optional, Set
 from unittest.mock import Mock
 
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
+from synapse.federation.units import Transaction
+from synapse.handlers.device import DeviceHandler
 from synapse.rest import admin
 from synapse.rest.client import login
+from synapse.server import HomeServer
 from synapse.types import JsonDict, ReadReceipt
+from synapse.util import Clock
 
 from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase
@@ -36,16 +41,17 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
     re-enabled for the main process.
     """
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.federation_transport_client = Mock(spec=["send_transaction"])
         hs = self.setup_test_homeserver(
-            federation_transport_client=Mock(spec=["send_transaction"]),
+            federation_transport_client=self.federation_transport_client,
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
+        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
             return_value=make_awaitable({"test", "host2"})
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
+        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (  # type: ignore[assignment]
             hs.get_storage_controllers().state.get_current_hosts_in_room
         )
 
@@ -56,10 +62,8 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         config["federation_sender_instances"] = None
         return config
 
-    def test_send_receipts(self):
-        mock_send_transaction = (
-            self.hs.get_federation_transport_client().send_transaction
-        )
+    def test_send_receipts(self) -> None:
+        mock_send_transaction = self.federation_transport_client.send_transaction
         mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
@@ -98,10 +102,8 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
             ],
         )
 
-    def test_send_receipts_thread(self):
-        mock_send_transaction = (
-            self.hs.get_federation_transport_client().send_transaction
-        )
+    def test_send_receipts_thread(self) -> None:
+        mock_send_transaction = self.federation_transport_client.send_transaction
         mock_send_transaction.return_value = make_awaitable({})
 
         # Create receipts for:
@@ -174,12 +176,10 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
             ],
         )
 
-    def test_send_receipts_with_backoff(self):
+    def test_send_receipts_with_backoff(self) -> None:
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
-        mock_send_transaction = (
-            self.hs.get_federation_transport_client().send_transaction
-        )
+        mock_send_transaction = self.federation_transport_client.send_transaction
         mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
@@ -272,51 +272,60 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.federation_transport_client = Mock(
+            spec=["send_transaction", "query_user_devices"]
+        )
         return self.setup_test_homeserver(
-            federation_transport_client=Mock(
-                spec=["send_transaction", "query_user_devices"]
-            ),
+            federation_transport_client=self.federation_transport_client,
         )
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         c = super().default_config()
         # Enable federation sending on the main process.
         c["federation_sender_instances"] = None
         return c
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         test_room_id = "!room:host1"
 
         # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
         # server thinks the user shares a room with `@user2:host2`
-        def get_rooms_for_user(user_id):
-            return defer.succeed({test_room_id})
+        def get_rooms_for_user(user_id: str) -> "defer.Deferred[FrozenSet[str]]":
+            return defer.succeed(frozenset({test_room_id}))
 
-        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user  # type: ignore[assignment]
 
-        async def get_current_hosts_in_room(room_id):
+        async def get_current_hosts_in_room(room_id: str) -> Set[str]:
             if room_id == test_room_id:
-                return ["host2"]
+                return {"host2"}
+            else:
+                # TODO: We should fail the test when we encounter an unxpected room ID.
+                # We can't just use `self.fail(...)` here because the app code is greedy
+                # with `Exception` and will catch it before the test can see it.
+                return set()
 
-            # TODO: We should fail the test when we encounter an unxpected room ID.
-            # We can't just use `self.fail(...)` here because the app code is greedy
-            # with `Exception` and will catch it before the test can see it.
+        hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room  # type: ignore[assignment]
 
-        hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room
+        device_handler = hs.get_device_handler()
+        assert isinstance(device_handler, DeviceHandler)
+        self.device_handler = device_handler
 
         # whenever send_transaction is called, record the edu data
-        self.edus = []
-        self.hs.get_federation_transport_client().send_transaction.side_effect = (
+        self.edus: List[JsonDict] = []
+        self.federation_transport_client.send_transaction.side_effect = (
             self.record_transaction
         )
 
-    def record_transaction(self, txn, json_cb):
+    def record_transaction(
+        self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
+    ) -> "defer.Deferred[JsonDict]":
+        assert json_cb is not None
         data = json_cb()
         self.edus.extend(data["edus"])
         return defer.succeed({})
 
-    def test_send_device_updates(self):
+    def test_send_device_updates(self) -> None:
         """Basic case: each device update should result in an EDU"""
         # create a device
         u1 = self.register_user("user", "pass")
@@ -340,12 +349,12 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.assertEqual(len(self.edus), 1)
         self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
 
-    def test_dont_send_device_updates_for_remote_users(self):
+    def test_dont_send_device_updates_for_remote_users(self) -> None:
         """Check that we don't send device updates for remote users"""
 
         # Send the server a device list EDU for the other user, this will cause
         # it to try and resync the device lists.
-        self.hs.get_federation_transport_client().query_user_devices.return_value = (
+        self.federation_transport_client.query_user_devices.return_value = (
             make_awaitable(
                 {
                     "stream_id": "1",
@@ -356,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         )
 
         self.get_success(
-            self.hs.get_device_handler().device_list_updater.incoming_device_list_update(
+            self.device_handler.device_list_updater.incoming_device_list_update(
                 "host2",
                 {
                     "user_id": "@user2:host2",
@@ -379,7 +388,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         )
         self.assertIn("D1", devices)
 
-    def test_upload_signatures(self):
+    def test_upload_signatures(self) -> None:
         """Uploading signatures on some devices should produce updates for that user"""
 
         e2e_handler = self.hs.get_e2e_keys_handler()
@@ -391,7 +400,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # expect two edus
         self.assertEqual(len(self.edus), 2)
-        stream_id = None
+        stream_id: Optional[int] = None
         stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
         stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
 
@@ -473,13 +482,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
             c = edu["content"]
             if stream_id is not None:
-                self.assertEqual(c["prev_id"], [stream_id])
+                self.assertEqual(c["prev_id"], [stream_id])  # type: ignore[unreachable]
                 self.assertGreaterEqual(c["stream_id"], stream_id)
             stream_id = c["stream_id"]
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2"}, devices)
 
-    def test_delete_devices(self):
+    def test_delete_devices(self) -> None:
         """If devices are deleted, that should result in EDUs too"""
 
         # create devices
@@ -499,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
 
         # delete them again
-        self.get_success(
-            self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
-        )
+        self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
 
         # We queue up device list updates to be sent over federation, so we
         # advance to clear the queue.
@@ -521,11 +528,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2", "D3"}, devices)
 
-    def test_unreachable_server(self):
+    def test_unreachable_server(self) -> None:
         """If the destination server is unreachable, all the updates should get sent on
         recovery
         """
-        mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+        mock_send_txn = self.federation_transport_client.send_transaction
         mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
 
         # create devices
@@ -535,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.login("user", "pass", device_id="D3")
 
         # delete them again
-        self.get_success(
-            self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
-        )
+        self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
 
         # We queue up device list updates to be sent over federation, so we
         # advance to clear the queue.
@@ -555,7 +560,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # for each device, there should be a single update
         self.assertEqual(len(self.edus), 3)
-        stream_id = None
+        stream_id: Optional[int] = None
         for edu in self.edus:
             self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
             c = edu["content"]
@@ -566,13 +571,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2", "D3"}, devices)
 
-    def test_prune_outbound_device_pokes1(self):
+    def test_prune_outbound_device_pokes1(self) -> None:
         """If a destination is unreachable, and the updates are pruned, we should get
         a single update.
 
         This case tests the behaviour when the server has never been reachable.
         """
-        mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+        mock_send_txn = self.federation_transport_client.send_transaction
         mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
 
         # create devices
@@ -582,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.login("user", "pass", device_id="D3")
 
         # delete them again
-        self.get_success(
-            self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
-        )
+        self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
 
         # We queue up device list updates to be sent over federation, so we
         # advance to clear the queue.
@@ -615,7 +618,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         # synapse uses an empty prev_id list to indicate "needs a full resync".
         self.assertEqual(c["prev_id"], [])
 
-    def test_prune_outbound_device_pokes2(self):
+    def test_prune_outbound_device_pokes2(self) -> None:
         """If a destination is unreachable, and the updates are pruned, we should get
         a single update.
 
@@ -632,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
 
         # now the server goes offline
-        mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+        mock_send_txn = self.federation_transport_client.send_transaction
         mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
 
         self.login("user", "pass", device_id="D2")
@@ -643,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.reactor.advance(1)
 
         # delete them again
-        self.get_success(
-            self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
-        )
+        self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
 
         self.assertGreaterEqual(mock_send_txn.call_count, 3)
 
@@ -741,7 +742,7 @@ def encode_pubkey(sk: SigningKey) -> str:
     return key.encode_verify_key_base64(key.get_verify_key(sk))
 
 
-def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
+def build_device_dict(user_id: str, device_id: str, sk: SigningKey) -> JsonDict:
     """Build a dict representing the given device"""
     return {
         "user_id": user_id,
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index be719e49c0..6c7738d810 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -21,7 +21,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.federation.federation_server import server_matches_acl_event
 from synapse.rest import admin
 from synapse.rest.client import login, room
@@ -34,7 +34,6 @@ from tests.unittest import override_config
 
 
 class FederationServerTests(unittest.FederatingHomeserverTestCase):
-
     servlets = [
         admin.register_servlets,
         room.register_servlets,
@@ -42,7 +41,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
     ]
 
     @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)])
-    def test_bad_request(self, query_content):
+    def test_bad_request(self, query_content: bytes) -> None:
         """
         Querying with bad data returns a reasonable error code.
         """
@@ -64,7 +63,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
 
 
 class ServerACLsTestCase(unittest.TestCase):
-    def test_blacklisted_server(self):
+    def test_blacklisted_server(self) -> None:
         e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
         logging.info("ACL event: %s", e.content)
 
@@ -74,7 +73,7 @@ class ServerACLsTestCase(unittest.TestCase):
         self.assertTrue(server_matches_acl_event("evil.com.au", e))
         self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
 
-    def test_block_ip_literals(self):
+    def test_block_ip_literals(self) -> None:
         e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
         logging.info("ACL event: %s", e.content)
 
@@ -83,7 +82,7 @@ class ServerACLsTestCase(unittest.TestCase):
         self.assertFalse(server_matches_acl_event("[1:2::]", e))
         self.assertTrue(server_matches_acl_event("1:2:3:4", e))
 
-    def test_wildcard_matching(self):
+    def test_wildcard_matching(self) -> None:
         e = _create_acl_event({"allow": ["good*.com"]})
         self.assertTrue(
             server_matches_acl_event("good.com", e),
@@ -110,7 +109,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
         login.register_servlets,
     ]
 
-    def test_needs_to_be_in_room(self):
+    def test_needs_to_be_in_room(self) -> None:
         """/v1/state/<room_id> requires the server to be in the room"""
         u1 = self.register_user("u1", "pass")
         u1_token = self.login("u1", "pass")
@@ -131,7 +130,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
 
         self._storage_controllers = hs.get_storage_controllers()
@@ -157,7 +156,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
         self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
         return channel.json_body
 
-    def test_send_join(self):
+    def test_send_join(self) -> None:
         """happy-path test of send_join"""
         joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
         join_result = self._make_join(joining_user)
@@ -324,7 +323,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
     #   is probably sufficient to reassure that the bucket is updated.
 
 
-def _create_acl_event(content):
+def _create_acl_event(content: JsonDict) -> EventBase:
     return make_event_from_dict(
         {
             "room_id": "!a:b",
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index e88e5d8bb3..55655de862 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -15,6 +15,8 @@
 from http import HTTPStatus
 from typing import Dict, List, Tuple
 
+from twisted.web.resource import Resource
+
 from synapse.api.errors import Codes
 from synapse.federation.transport.server import BaseFederationServlet
 from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
@@ -62,7 +64,7 @@ class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCa
 
     path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}"
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> Resource:
         """Overrides `HomeserverTestCase.create_test_resource`."""
         resource = JsonResource(self.hs)
 
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index ff589c0b6c..70209ab090 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import OrderedDict
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, JoinRules, Membership
-from synapse.api.room_versions import RoomVersions
-from synapse.events import builder
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.events import EventBase, builder
+from synapse.events.snapshot import EventContext
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.types import RoomAlias
+from synapse.util import Clock
 
 from tests.test_utils import event_injection
 from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase
@@ -197,7 +201,9 @@ class FederationKnockingTestCase(
         login.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.store = homeserver.get_datastores().main
 
         # We're not going to be properly signing events as our remote homeserver is fake,
@@ -205,23 +211,29 @@ class FederationKnockingTestCase(
         # Note that these checks are not relevant to this test case.
 
         # Have this homeserver auto-approve all event signature checking.
-        async def approve_all_signature_checking(_, pdu):
+        async def approve_all_signature_checking(
+            room_version: RoomVersion,
+            pdu: EventBase,
+            record_failure_callback: Any = None,
+        ) -> EventBase:
             return pdu
 
-        homeserver.get_federation_server()._check_sigs_and_hash = (
+        homeserver.get_federation_server()._check_sigs_and_hash = (  # type: ignore[assignment]
             approve_all_signature_checking
         )
 
         # Have this homeserver skip event auth checks. This is necessary due to
         # event auth checks ensuring that events were signed by the sender's homeserver.
-        async def _check_event_auth(origin, event, context, *args, **kwargs):
-            return context
+        async def _check_event_auth(
+            origin: Optional[str], event: EventBase, context: EventContext
+        ) -> None:
+            pass
 
-        homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth
+        homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[assignment]
 
         return super().prepare(reactor, clock, homeserver)
 
-    def test_room_state_returned_when_knocking(self):
+    def test_room_state_returned_when_knocking(self) -> None:
         """
         Tests that specific, stripped state events from a room are returned after
         a remote homeserver successfully knocks on a local room.
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index cfd550a04b..c4231f4aa9 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -20,7 +20,7 @@ from tests.unittest import DEBUG, override_config
 
 class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
     @override_config({"allow_public_rooms_over_federation": False})
-    def test_blocked_public_room_list_over_federation(self):
+    def test_blocked_public_room_list_over_federation(self) -> None:
         """Test that unauthenticated requests to the public rooms directory 403 when
         allow_public_rooms_over_federation is False.
         """
@@ -31,7 +31,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
         self.assertEqual(403, channel.code)
 
     @override_config({"allow_public_rooms_over_federation": True})
-    def test_open_public_room_list_over_federation(self):
+    def test_open_public_room_list_over_federation(self) -> None:
         """Test that unauthenticated requests to the public rooms directory 200 when
         allow_public_rooms_over_federation is True.
         """
@@ -42,7 +42,7 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
         self.assertEqual(200, channel.code)
 
     @DEBUG
-    def test_edu_debugging_doesnt_explode(self):
+    def test_edu_debugging_doesnt_explode(self) -> None:
         """Sanity check incoming federation succeeds with `synapse.debug_8631` enabled.
 
         Remove this when we strip out issue_8631_logger.
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 6f300b8e11..5569ccef8a 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, JoinRules
 from synapse.api.room_versions import RoomVersions
 from synapse.rest.client import knock, login, room
 from synapse.server import HomeServer
+from synapse.types import UserID
 from synapse.util import Clock
 
 from tests import unittest
@@ -296,3 +297,58 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(args[0][0]["user_agent"], "user_agent")
         self.assertGreater(args[0][0]["last_seen"], 0)
         self.assertNotIn("access_token", args[0][0])
+
+    def test_account_data(self) -> None:
+        """Tests that user account data get exported."""
+        # add account data
+        self.get_success(
+            self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1})
+        )
+        self.get_success(
+            self._store.add_account_data_to_room(
+                self.user2, "test_room", "m.per_room", {"b": 2}
+            )
+        )
+
+        writer = Mock()
+
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+        # two calls, one call for user data and one call for room data
+        writer.write_account_data.assert_called()
+
+        args = writer.write_account_data.call_args_list[0][0]
+        self.assertEqual(args[0], "global")
+        self.assertEqual(args[1]["m.global"]["a"], 1)
+
+        args = writer.write_account_data.call_args_list[1][0]
+        self.assertEqual(args[0], "test_room")
+        self.assertEqual(args[1]["m.per_room"]["b"], 2)
+
+    def test_media_ids(self) -> None:
+        """Tests that media's metadata get exported."""
+
+        self.get_success(
+            self._store.store_local_media(
+                media_id="media_1",
+                media_type="image/png",
+                time_now_ms=self.clock.time_msec(),
+                upload_name=None,
+                media_length=50,
+                user_id=UserID.from_string(self.user2),
+            )
+        )
+
+        writer = Mock()
+
+        self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+        writer.write_media_id.assert_called_once()
+
+        args = writer.write_media_id.call_args[0]
+        self.assertEqual(args[0], "media_1")
+        self.assertEqual(args[1]["media_id"], "media_1")
+        self.assertEqual(args[1]["media_length"], 50)
+        self.assertGreater(args[1]["created_ts"], 0)
+        self.assertIsNone(args[1]["upload_name"])
+        self.assertIsNone(args[1]["last_access_ts"])
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7495ab21a..9014e60577 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
 
         # Mock out application services, and allow defining our own in tests
         self._services: List[ApplicationService] = []
-        self.hs.get_datastores().main.get_app_services = Mock(
+        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[assignment]
             return_value=self._services
         )
 
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 2733719d82..63aad0d10c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         cas_response = CasResponse("test_user", {})
         request = _mock_request()
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # Map a user via SSO.
         cas_response = CasResponse("test_user", {})
@@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         cas_response = CasResponse("föö", {})
         request = _mock_request()
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # The response doesn't have the proper userGroup or department.
         cas_response = CasResponse("test_user", {})
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 95698bc275..6b4cba65d0 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # we should now have an unused alg1 key
-        res = self.get_success(
+        fallback_res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
-        self.assertEqual(res, ["alg1"])
+        self.assertEqual(fallback_res, ["alg1"])
 
         # claiming an OTK when no OTKs are available should return the fallback
         # key
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
 
         # we shouldn't have any unused fallback keys again
-        res = self.get_success(
+        unused_res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
-        self.assertEqual(res, [])
+        self.assertEqual(unused_res, [])
 
         # claiming an OTK again should return the same fallback key
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
 
@@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        res = self.get_success(
+        unused_res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
-        self.assertEqual(res, [])
+        self.assertEqual(unused_res, [])
 
         # uploading a new fallback key should result in an unused fallback key
         self.get_success(
@@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        res = self.get_success(
+        unused_res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
-        self.assertEqual(res, ["alg1"])
+        self.assertEqual(unused_res, ["alg1"])
 
         # if the user uploads a one-time key, the next claim should fetch the
         # one-time key, and then go back to the fallback
@@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
         )
 
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
         )
 
@@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
 
@@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
         # upload two device keys, which will be signed later by the self-signing key
-        device_key_1 = {
+        device_key_1: JsonDict = {
             "user_id": local_user,
             "device_id": "abc",
             "algorithms": [
@@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
             "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
         }
-        device_key_2 = {
+        device_key_2: JsonDict = {
             "user_id": local_user,
             "device_id": "def",
             "algorithms": [
@@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         }
         self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
+        device_handler = self.hs.get_device_handler()
+        assert isinstance(device_handler, DeviceHandler)
         e = self.get_failure(
-            self.hs.get_device_handler().check_device_registered(
+            device_handler.check_device_registered(
                 user_id=local_user,
                 device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
                 initial_device_display_name="new display name",
@@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         device_id = "xyz"
         # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
         device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
-        device_key = {
+        device_key: JsonDict = {
             "user_id": local_user,
             "device_id": device_id,
             "algorithms": [
@@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
         master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
-        master_key = {
+        master_key: JsonDict = {
             "user_id": local_user,
             "usage": ["master"],
             "keys": {"ed25519:" + master_pubkey: master_pubkey},
@@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # the first user
         other_user = "@otherboris:" + self.hs.hostname
         other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
-        other_master_key = {
+        other_master_key: JsonDict = {
             # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
             "user_id": other_user,
             "usage": ["master"],
@@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_client_keys = mock.Mock(
+        self.hs.get_federation_client().query_client_keys = mock.Mock(  # type: ignore[assignment]
             return_value=make_awaitable(
                 {
                     "device_keys": {remote_user_id: {}},
@@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_user_devices = mock.Mock(
+        self.hs.get_federation_client().query_user_devices = mock.Mock(  # type: ignore[assignment]
             return_value=make_awaitable(
                 {
                     "user_id": remote_user_id,
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 57675fa407..bf0862ed54 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
         # We mock out the FederationClient.backfill method, to pretend that a remote
         # server has returned our fake event.
         federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
-        self.hs.get_federation_client().backfill = federation_client_backfill_mock
+        self.hs.get_federation_client().backfill = federation_client_backfill_mock  # type: ignore[assignment]
 
         # We also mock the persist method with a side effect of itself. This allows us
         # to track when it has been called while preserving its function.
         persist_events_and_notify_mock = Mock(
             side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
         )
-        self.hs.get_federation_event_handler().persist_events_and_notify = (
+        self.hs.get_federation_event_handler().persist_events_and_notify = (  # type: ignore[assignment]
             persist_events_and_notify_mock
         )
 
@@ -575,26 +575,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
         fed_client = fed_handler.federation_client
 
         room_id = "!room:example.com"
-        membership_event = make_event_from_dict(
-            {
-                "room_id": room_id,
-                "type": "m.room.member",
-                "sender": "@alice:test",
-                "state_key": "@alice:test",
-                "content": {"membership": "join"},
-            },
-            RoomVersions.V10,
-        )
-
-        mock_make_membership_event = Mock(
-            return_value=make_awaitable(
-                (
-                    "example.com",
-                    membership_event,
-                    RoomVersions.V10,
-                )
-            )
-        )
 
         EVENT_CREATE = make_event_from_dict(
             {
@@ -640,6 +620,26 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             },
             room_version=RoomVersions.V10,
         )
+        membership_event = make_event_from_dict(
+            {
+                "room_id": room_id,
+                "type": "m.room.member",
+                "sender": "@alice:test",
+                "state_key": "@alice:test",
+                "content": {"membership": "join"},
+                "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
+            },
+            RoomVersions.V10,
+        )
+        mock_make_membership_event = Mock(
+            return_value=make_awaitable(
+                (
+                    "example.com",
+                    membership_event,
+                    RoomVersions.V10,
+                )
+            )
+        )
         mock_send_join = Mock(
             return_value=make_awaitable(
                 SendJoinResult(
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
         ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
             # Start the partial state sync.
-            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # Try to start another partial state sync.
             # Nothing should happen.
-            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # End the partial state sync
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
 
             # The next attempt to start the partial state sync should work.
             is_partial_state = True
-            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 2)
 
     def test_partial_state_room_sync_restart(self) -> None:
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
         ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
             # Start the partial state sync.
-            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # Fail the partial state sync.
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # Start the partial state sync again.
-            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 2)
 
             # Deduplicate another partial state sync.
-            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 2)
 
             # Fail the partial state sync.
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             self.assertEqual(mock_sync_partial_state_room.call_count, 3)
             mock_sync_partial_state_room.assert_called_with(
                 initial_destination="hs3",
-                other_destinations=["hs2"],
+                other_destinations={"hs2"},
                 room_id="room_id",
             )
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 70ea4d15d4..c067e5bfe3 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
+from synapse.state import StateResolutionStore
 from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         if prev_exists_as_outlier:
             prev_event.internal_metadata.outlier = True
             persistence = self.hs.get_storage_controllers().persistence
+            assert persistence is not None
             self.get_success(
                 persistence.persist_event(
                     prev_event,
@@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                         bert_member_event.event_id: bert_member_event,
                         rejected_kick_event.event_id: rejected_kick_event,
                     },
-                    state_res_store=main_store,
+                    state_res_store=StateResolutionStore(main_store),
                 )
             ),
             [bert_member_event.event_id, rejected_kick_event.event_id],
@@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                         rejected_power_levels_event.event_id,
                     ],
                     event_map={},
-                    state_res_store=main_store,
+                    state_res_store=StateResolutionStore(main_store),
                     full_conflicted_set=set(),
                 )
             ),
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index c4727ab917..9691d66b48 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = self.hs.get_event_creation_handler()
-        self._persist_event_storage_controller = (
-            self.hs.get_storage_controllers().persistence
-        )
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self._persist_event_storage_controller = persistence
 
         self.user_id = self.register_user("tester", "foobar")
         self.access_token = self.login("tester", "foobar")
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
 
-        self.info = self.get_success(
+        info = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(
                 self.access_token,
             )
         )
-        self.token_id = self.info.token_id
+        assert info is not None
+        self.token_id = info.token_id
 
         self.requester = create_requester(self.user_id, access_token_id=self.token_id)
 
@@ -78,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         return memberEvent, memberEventContext
 
-    def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
+    def _create_duplicate_event(
+        self, txn_id: str
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """Create a new event with the given transaction ID. All events produced
         by this method will be considered duplicates.
         """
@@ -106,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         txn_id = "something_suitably_random"
 
-        event1, context = self._create_duplicate_event(txn_id)
+        event1, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event1))
 
         ret_event1 = self.get_success(
             self.handler.handle_new_client_event(
@@ -118,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(event1.event_id, ret_event1.event_id)
 
-        event2, context = self._create_duplicate_event(txn_id)
+        event2, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event2))
 
         # We want to test that the deduplication at the persit event end works,
         # so we want to make sure we test with different events.
@@ -139,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         # Let's test that calling `persist_event` directly also does the right
         # thing.
-        event3, context = self._create_duplicate_event(txn_id)
+        event3, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event3))
+
         self.assertNotEqual(event1.event_id, event3.event_id)
 
         ret_event3, event_pos3, _ = self.get_success(
@@ -153,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         # Let's test that calling `persist_events` directly also does the right
         # thing.
-        event4, context = self._create_duplicate_event(txn_id)
+        event4, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event4))
         self.assertNotEqual(event1.event_id, event3.event_id)
 
         events, _ = self.get_success(
@@ -173,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         txn_id = "something_else_suitably_random"
 
         # Create two duplicate events to persist at the same time
-        event1, context1 = self._create_duplicate_event(txn_id)
-        event2, context2 = self._create_duplicate_event(txn_id)
+        event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
+        context1 = self.get_success(unpersisted_context1.persist(event1))
+        event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
+        context2 = self.get_success(unpersisted_context2.persist(event2))
 
         # Ensure their event IDs are different to start with
         self.assertNotEqual(event1.event_id, event2.event_id)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index ef5311ce64..bb52b3b1af 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         hs = self.setup_test_homeserver()
         self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
-        self.hs_patcher.start()
+        self.hs_patcher.start()  # type: ignore[attr-defined]
 
         self.handler = hs.get_oidc_handler()
         self.provider = self.handler._providers["oidc"]
@@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def tearDown(self) -> None:
-        self.hs_patcher.stop()
+        self.hs_patcher.stop()  # type: ignore[attr-defined]
         return super().tearDown()
 
     def reset_mocks(self) -> None:
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 0916de64f5..aa91bc0a3d 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             username: The username to use for the test.
             registration: Whether to test with registration URLs.
         """
-        self.hs.get_identity_handler().send_threepid_validation = Mock(
+        self.hs.get_identity_handler().send_threepid_validation = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(0),
         )
 
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b9332d97dc..aff1ec4758 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -62,7 +62,7 @@ class TestSpamChecker:
         request_info: Collection[Tuple[str, str]],
         auth_provider_id: Optional[str],
     ) -> RegistrationBehaviour:
-        pass
+        return RegistrationBehaviour.ALLOW
 
 
 class DenyAll(TestSpamChecker):
@@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker:
         username: Optional[str],
         request_info: Collection[Tuple[str, str]],
     ) -> RegistrationBehaviour:
-        pass
+        return RegistrationBehaviour.ALLOW
 
 
 class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
@@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_not_blocked(self) -> None:
-        self.store.count_monthly_users = Mock(
+        self.store.count_monthly_users = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
         )
         # Ensure does not throw exception
@@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
         room_alias_str = "#room:test"
 
-        self.store.count_real_users = Mock(return_value=make_awaitable(1))
+        self.store.count_real_users = Mock(return_value=make_awaitable(1))  # type: ignore[assignment]
         self.store.is_real_user = Mock(return_value=make_awaitable(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
         self,
     ) -> None:
-        self.store.count_real_users = Mock(return_value=make_awaitable(2))
+        self.store.count_real_users = Mock(return_value=make_awaitable(2))  # type: ignore[assignment]
         self.store.is_real_user = Mock(return_value=make_awaitable(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # Ensure the room is properly not federated.
         room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        assert room is not None
         self.assertFalse(room["federatable"])
         self.assertFalse(room["public"])
         self.assertEqual(room["join_rules"], "public")
@@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # Ensure the room is properly a public room.
         room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        assert room is not None
         self.assertEqual(room["join_rules"], "public")
 
         # Both users should be in the room.
@@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # Ensure the room is properly a private room.
         room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        assert room is not None
         self.assertFalse(room["public"])
         self.assertEqual(room["join_rules"], "invite")
         self.assertEqual(room["guest_access"], "can_join")
@@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # Ensure the room is properly a private room.
         room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        assert room is not None
         self.assertFalse(room["public"])
         self.assertEqual(room["join_rules"], "invite")
         self.assertEqual(room["guest_access"], "can_join")
@@ -503,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         # Lower the permissions of the inviter.
         event_creation_handler = self.hs.get_event_creation_handler()
         requester = create_requester(inviter)
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_creation_handler.create_event(
                 requester,
                 {
@@ -515,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
                 },
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             event_creation_handler.handle_new_client_event(
                 requester, events_and_context=[(event, context)]
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 9b1b8b9f13..b5c772a7ae 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # send a mocked-up SAML response to the callback
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # Map a user via SSO.
         saml_response = FakeAuthnResponse(
@@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # mock out the error renderer too
         sso_handler = self.hs.get_sso_handler()
-        sso_handler.render_error = Mock(return_value=None)
+        sso_handler.render_error = Mock(return_value=None)  # type: ignore[assignment]
 
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
         request = _mock_request()
@@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler and error renderer
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
         sso_handler = self.hs.get_sso_handler()
-        sso_handler.render_error = Mock(return_value=None)
+        sso_handler.render_error = Mock(return_value=None)  # type: ignore[assignment]
 
         # register a user to occupy the first-choice MXID
         store = self.hs.get_datastores().main
@@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # The response doesn't have the proper userGroup or department.
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
index 137deab138..d6f43a98fc 100644
--- a/tests/handlers/test_sso.py
+++ b/tests/handlers/test_sso.py
@@ -113,7 +113,6 @@ async def mock_get_file(
     headers: Optional[RawHeaders] = None,
     is_allowed_content_type: Optional[Callable[[str], bool]] = None,
 ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
-
     fake_response = FakeResponse(code=404)
     if url == "http://my.server/me.png":
         fake_response = FakeResponse(
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index f1a50c5bcb..d11ded6c5b 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -31,7 +31,6 @@ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
 
 
 class StatsRoomTests(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1fe9563c98..94518a7196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
 
         # we mock out the federation client too
-        mock_federation_client = Mock(spec=["put_json"])
-        mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+        self.mock_federation_client = Mock(spec=["put_json"])
+        self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
 
         # the tests assume that we are starting at unix time 1000
         reactor.pump((1000,))
@@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.mock_hs_notifier = Mock()
         hs = self.setup_test_homeserver(
             notifier=self.mock_hs_notifier,
-            federation_http_client=mock_federation_client,
+            federation_http_client=self.mock_federation_client,
             keyring=mock_keyring,
             replication_streams={},
         )
@@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        put_json = self.hs.get_federation_http_client().put_json
-        put_json.assert_called_once_with(
+        self.mock_federation_client.put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
             data=_expect_edu_transaction(
@@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
-        put_json = self.hs.get_federation_http_client().put_json
-        put_json.assert_called_once_with(
+        self.mock_federation_client.put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
             data=_expect_edu_transaction(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..a02c1c6227 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple
+from typing import Any, Tuple
 from unittest.mock import Mock, patch
 from urllib.parse import quote
 
@@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService
 from synapse.rest.client import login, register, room, user_directory
 from synapse.server import HomeServer
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import create_requester
+from synapse.types import UserProfile, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import override_config
 
 
+# A spam checker which doesn't implement anything, so create a bare object.
+class UselessSpamChecker:
+    def __init__(self, config: Any):
+        pass
+
+
 class UserDirectoryTestCase(unittest.HomeserverTestCase):
     """Tests the UserDirectoryHandler.
 
@@ -186,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
         self._check_only_one_user_in_directory(user, room)
 
+    def test_search_term_with_colon_in_it_does_not_raise(self) -> None:
+        """
+        Regression test: Test that search terms with colons in them are acceptable.
+        """
+        u1 = self.register_user("user1", "pass")
+        self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10))
+
     def test_user_not_in_users_table(self) -> None:
         """Unclear how it happens, but on matrix.org we've seen join events
         for users who aren't in the users table. Test that we don't fall over
@@ -773,7 +786,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 1)
 
-        async def allow_all(user_profile: ProfileInfo) -> bool:
+        async def allow_all(user_profile: UserProfile) -> bool:
             # Allow all users.
             return False
 
@@ -787,7 +800,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(s["results"]), 1)
 
         # Configure a spam checker that filters all users.
-        async def block_all(user_profile: ProfileInfo) -> bool:
+        async def block_all(user_profile: UserProfile) -> bool:
             # All users are spammy.
             return True
 
@@ -797,6 +810,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 0)
 
+    @override_config(
+        {
+            "spam_checker": {
+                "module": "tests.handlers.test_user_directory.UselessSpamChecker"
+            }
+        }
+    )
     def test_legacy_spam_checker(self) -> None:
         """
         A spam checker without the expected method should be ignored.
@@ -825,11 +845,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
         self.assertEqual(public_users, set())
 
-        # Configure a spam checker.
-        spam_checker = self.hs.get_spam_checker()
-        # The spam checker doesn't need any methods, so create a bare object.
-        spam_checker.spam_checker = object()
-
         # We get one search result when searching for user2 by user1.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 1)
@@ -949,13 +964,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(
-            self.hs.get_storage_controllers().persistence.persist_event(event, context)
-        )
+        context = self.get_success(unpersisted_context.persist(event))
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self.get_success(persistence.persist_event(event, context))
 
     def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
         """We've chosen to simplify the user directory's implementation by
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 093537adef..528cdee34b 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -19,13 +19,15 @@ from zope.interface import implementer
 
 from OpenSSL import SSL
 from OpenSSL.SSL import Connection
+from twisted.internet.address import IPv4Address
 from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
 from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.protocols.tls import TLSMemoryBIOProtocol
 from twisted.web.client import BrowserLikePolicyForHTTPS  # noqa: F401
 from twisted.web.iweb import IPolicyForHTTPS  # noqa: F401
 
 
-def get_test_https_policy():
+def get_test_https_policy() -> BrowserLikePolicyForHTTPS:
     """Get a test IPolicyForHTTPS which trusts the test CA cert
 
     Returns:
@@ -39,7 +41,7 @@ def get_test_https_policy():
     return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
 
 
-def get_test_ca_cert_file():
+def get_test_ca_cert_file() -> str:
     """Get the path to the test CA cert
 
     The keypair is generated with:
@@ -51,7 +53,7 @@ def get_test_ca_cert_file():
     return os.path.join(os.path.dirname(__file__), "ca.crt")
 
 
-def get_test_key_file():
+def get_test_key_file() -> str:
     """get the path to the test key
 
     The key file is made with:
@@ -137,15 +139,20 @@ class TestServerTLSConnectionFactory:
     """An SSL connection creator which returns connections which present a certificate
     signed by our test CA."""
 
-    def __init__(self, sanlist):
+    def __init__(self, sanlist: List[bytes]):
         """
         Args:
-            sanlist: list[bytes]: a list of subjectAltName values for the cert
+            sanlist: a list of subjectAltName values for the cert
         """
         self._cert_file = create_test_cert_file(sanlist)
 
-    def serverConnectionForTLS(self, tlsProtocol):
+    def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connection:
         ctx = SSL.Context(SSL.SSLv23_METHOD)
         ctx.use_certificate_file(self._cert_file)
         ctx.use_privatekey_file(get_test_key_file())
         return Connection(ctx, None)
+
+
+# A dummy address, useful for tests that use FakeTransport and don't care about where
+# packets are going to/coming from.
+dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 992d8f94fd..eb7f53fee5 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -14,7 +14,7 @@
 import base64
 import logging
 import os
-from typing import Iterable, Optional
+from typing import Any, Awaitable, Callable, Generator, List, Optional, cast
 from unittest.mock import Mock, patch
 
 import treq
@@ -24,14 +24,19 @@ from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
-from twisted.internet.interfaces import IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.defer import Deferred
+from twisted.internet.endpoints import _WrappingProtocol
+from twisted.internet.interfaces import (
+    IOpenSSLClientConnectionCreator,
+    IProtocolFactory,
+)
+from twisted.internet.protocol import Factory, Protocol
 from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.web._newclient import ResponseNeverReceived
 from twisted.web.client import Agent
 from twisted.web.http import HTTPChannel, Request
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IPolicyForHTTPS
+from twisted.web.iweb import IPolicyForHTTPS, IResponse
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
@@ -42,27 +47,39 @@ from synapse.http.federation.well_known_resolver import (
     WellKnownResolver,
     _cache_period_from_headers,
 )
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import (
+    SENTINEL_CONTEXT,
+    LoggingContext,
+    LoggingContextOrSentinel,
+    current_context,
+)
+from synapse.types import ISynapseReactor
 from synapse.util.caches.ttlcache import TTLCache
 
 from tests import unittest
-from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
+from tests.http import (
+    TestServerTLSConnectionFactory,
+    dummy_address,
+    get_test_ca_cert_file,
+)
 from tests.server import FakeTransport, ThreadedMemoryReactorClock
-from tests.utils import default_config
+from tests.utils import checked_cast, default_config
 
 logger = logging.getLogger(__name__)
 
 
 # Once Async Mocks or lambdas are supported this can go away.
-def generate_resolve_service(result):
-    async def resolve_service(_):
+def generate_resolve_service(
+    result: List[Server],
+) -> Callable[[Any], Awaitable[List[Server]]]:
+    async def resolve_service(_: Any) -> List[Server]:
         return result
 
     return resolve_service
 
 
 class MatrixFederationAgentTests(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.reactor = ThreadedMemoryReactorClock()
 
         self.mock_resolver = Mock()
@@ -75,8 +92,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
         self.tls_factory = FederationPolicyForHTTPS(config)
 
-        self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
-        self.had_well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+        self.well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache(
+            "test_cache", timer=self.reactor.seconds
+        )
+        self.had_well_known_cache: TTLCache[bytes, bool] = TTLCache(
+            "test_cache", timer=self.reactor.seconds
+        )
         self.well_known_resolver = WellKnownResolver(
             self.reactor,
             Agent(self.reactor, contextFactory=self.tls_factory),
@@ -89,8 +110,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self,
         client_factory: IProtocolFactory,
         ssl: bool = True,
-        expected_sni: bytes = None,
-        tls_sanlist: Optional[Iterable[bytes]] = None,
+        expected_sni: Optional[bytes] = None,
+        tls_sanlist: Optional[List[bytes]] = None,
     ) -> HTTPChannel:
         """Builds a test server, and completes the outgoing client connection
         Args:
@@ -116,8 +137,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
         if ssl:
             server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
 
-        server_protocol = server_factory.buildProtocol(None)
-
+        server_protocol = server_factory.buildProtocol(dummy_address)
+        assert server_protocol is not None
         # now, tell the client protocol factory to build the client protocol (it will be a
         # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
         # HTTP11ClientProtocol) and wire the output of said protocol up to the server via
@@ -125,7 +146,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
         #
         # Normally this would be done by the TCP socket code in Twisted, but we are
         # stubbing that out here.
-        client_protocol = client_factory.buildProtocol(None)
+        # NB: we use a checked_cast here to workaround https://github.com/Shoobx/mypy-zope/issues/91)
+        client_protocol = checked_cast(
+            _WrappingProtocol, client_factory.buildProtocol(dummy_address)
+        )
         client_protocol.makeConnection(
             FakeTransport(server_protocol, self.reactor, client_protocol)
         )
@@ -136,6 +160,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         )
 
         if ssl:
+            assert isinstance(server_protocol, TLSMemoryBIOProtocol)
             # fish the test server back out of the server-side TLS protocol.
             http_protocol = server_protocol.wrappedProtocol
             # grab a hold of the TLS connection, in case it gets torn down
@@ -144,6 +169,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             http_protocol = server_protocol
             tls_connection = None
 
+        assert isinstance(http_protocol, HTTPChannel)
         # give the reactor a pump to get the TLS juices flowing (if needed)
         self.reactor.advance(0)
 
@@ -159,12 +185,14 @@ class MatrixFederationAgentTests(unittest.TestCase):
         return http_protocol
 
     @defer.inlineCallbacks
-    def _make_get_request(self, uri: bytes):
+    def _make_get_request(
+        self, uri: bytes
+    ) -> Generator["Deferred[object]", object, IResponse]:
         """
         Sends a simple GET request via the agent, and checks its logcontext management
         """
         with LoggingContext("one") as context:
-            fetch_d = self.agent.request(b"GET", uri)
+            fetch_d: Deferred[IResponse] = self.agent.request(b"GET", uri)
 
             # Nothing happened yet
             self.assertNoResult(fetch_d)
@@ -172,8 +200,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
             # should have reset logcontext to the sentinel
             _check_logcontext(SENTINEL_CONTEXT)
 
+            fetch_res: IResponse
             try:
-                fetch_res = yield fetch_d
+                fetch_res = yield fetch_d  # type: ignore[misc, assignment]
                 return fetch_res
             except Exception as e:
                 logger.info("Fetch of %s failed: %s", uri.decode("ascii"), e)
@@ -216,7 +245,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         request: Request,
         content: bytes,
         headers: Optional[dict] = None,
-    ):
+    ) -> None:
         """Check that an incoming request looks like a valid .well-known request, and
         send back the response.
         """
@@ -237,16 +266,16 @@ class MatrixFederationAgentTests(unittest.TestCase):
         because it is created too early during setUp
         """
         return MatrixFederationAgent(
-            reactor=self.reactor,
+            reactor=cast(ISynapseReactor, self.reactor),
             tls_client_options_factory=self.tls_factory,
-            user_agent="test-agent",  # Note that this is unused since _well_known_resolver is provided.
+            user_agent=b"test-agent",  # Note that this is unused since _well_known_resolver is provided.
             ip_whitelist=IPSet(),
             ip_blacklist=IPSet(),
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=self.well_known_resolver,
         )
 
-    def test_get(self):
+    def test_get(self) -> None:
         """happy-path test of a GET request with an explicit port"""
         self._do_get()
 
@@ -254,11 +283,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
         os.environ,
         {"https_proxy": "proxy.com", "no_proxy": "testserv"},
     )
-    def test_get_bypass_proxy(self):
+    def test_get_bypass_proxy(self) -> None:
         """test of a GET request with an explicit port and bypass proxy"""
         self._do_get()
 
-    def _do_get(self):
+    def _do_get(self) -> None:
         """test of a GET request with an explicit port"""
         self.agent = self._make_agent()
 
@@ -318,7 +347,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
     @patch.dict(
         os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"}
     )
-    def test_get_via_http_proxy(self):
+    def test_get_via_http_proxy(self) -> None:
         """test for federation request through a http proxy"""
         self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None)
 
@@ -326,7 +355,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         os.environ,
         {"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"},
     )
-    def test_get_via_http_proxy_with_auth(self):
+    def test_get_via_http_proxy_with_auth(self) -> None:
         """test for federation request through a http proxy with authentication"""
         self._do_get_via_proxy(
             expect_proxy_ssl=False, expected_auth_credentials=b"user:pass"
@@ -335,7 +364,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
     @patch.dict(
         os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
     )
-    def test_get_via_https_proxy(self):
+    def test_get_via_https_proxy(self) -> None:
         """test for federation request through a https proxy"""
         self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None)
 
@@ -343,7 +372,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         os.environ,
         {"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"},
     )
-    def test_get_via_https_proxy_with_auth(self):
+    def test_get_via_https_proxy_with_auth(self) -> None:
         """test for federation request through a https proxy with authentication"""
         self._do_get_via_proxy(
             expect_proxy_ssl=True, expected_auth_credentials=b"user:pass"
@@ -353,7 +382,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self,
         expect_proxy_ssl: bool = False,
         expected_auth_credentials: Optional[bytes] = None,
-    ):
+    ) -> None:
         """Send a https federation request via an agent and check that it is correctly
             received at the proxy and client. The proxy can use either http or https.
         Args:
@@ -418,10 +447,11 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # now we make another test server to act as the upstream HTTP server.
         server_ssl_protocol = _wrap_server_factory_for_tls(
             _get_test_protocol_factory()
-        ).buildProtocol(None)
+        ).buildProtocol(dummy_address)
 
         # Tell the HTTP server to send outgoing traffic back via the proxy's transport.
         proxy_server_transport = proxy_server.transport
+        assert proxy_server_transport is not None
         server_ssl_protocol.makeConnection(proxy_server_transport)
 
         # ... and replace the protocol on the proxy's transport with the
@@ -436,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
         else:
             assert isinstance(proxy_server_transport, FakeTransport)
             client_protocol = proxy_server_transport.other
-            c2s_transport = client_protocol.transport
+            assert isinstance(client_protocol, Protocol)
+            c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
             c2s_transport.other = server_ssl_protocol
 
         self.reactor.advance(0)
@@ -451,6 +482,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
         # now there should be a pending request
         http_server = server_ssl_protocol.wrappedProtocol
+        assert isinstance(http_server, HTTPChannel)
         self.assertEqual(len(http_server.requests), 1)
 
         request = http_server.requests[0]
@@ -491,7 +523,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         json = self.successResultOf(treq.json_content(response))
         self.assertEqual(json, {"a": 1})
 
-    def test_get_ip_address(self):
+    def test_get_ip_address(self) -> None:
         """
         Test the behaviour when the server name contains an explicit IP (with no port)
         """
@@ -526,7 +558,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    def test_get_ipv6_address(self):
+    def test_get_ipv6_address(self) -> None:
         """
         Test the behaviour when the server name contains an explicit IPv6 address
         (with no port)
@@ -562,7 +594,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    def test_get_ipv6_address_with_port(self):
+    def test_get_ipv6_address_with_port(self) -> None:
         """
         Test the behaviour when the server name contains an explicit IPv6 address
         (with explicit port)
@@ -598,7 +630,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    def test_get_hostname_bad_cert(self):
+    def test_get_hostname_bad_cert(self) -> None:
         """
         Test the behaviour when the certificate on the server doesn't match the hostname
         """
@@ -651,7 +683,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         failure_reason = e.value.reasons[0]
         self.assertIsInstance(failure_reason.value, VerificationError)
 
-    def test_get_ip_address_bad_cert(self):
+    def test_get_ip_address_bad_cert(self) -> None:
         """
         Test the behaviour when the server name contains an explicit IP, but
         the server cert doesn't cover it
@@ -684,7 +716,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         failure_reason = e.value.reasons[0]
         self.assertIsInstance(failure_reason.value, VerificationError)
 
-    def test_get_no_srv_no_well_known(self):
+    def test_get_no_srv_no_well_known(self) -> None:
         """
         Test the behaviour when the server name has no port, no SRV, and no well-known
         """
@@ -740,7 +772,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    def test_get_well_known(self):
+    def test_get_well_known(self) -> None:
         """Test the behaviour when the .well-known delegates elsewhere"""
         self.agent = self._make_agent()
 
@@ -802,7 +834,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.well_known_cache.expire()
         self.assertNotIn(b"testserv", self.well_known_cache)
 
-    def test_get_well_known_redirect(self):
+    def test_get_well_known_redirect(self) -> None:
         """Test the behaviour when the server name has no port and no SRV record, but
         the .well-known has a 300 redirect
         """
@@ -892,7 +924,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.well_known_cache.expire()
         self.assertNotIn(b"testserv", self.well_known_cache)
 
-    def test_get_invalid_well_known(self):
+    def test_get_invalid_well_known(self) -> None:
         """
         Test the behaviour when the server name has an *invalid* well-known (and no SRV)
         """
@@ -945,7 +977,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    def test_get_well_known_unsigned_cert(self):
+    def test_get_well_known_unsigned_cert(self) -> None:
         """Test the behaviour when the .well-known server presents a cert
         not signed by a CA
         """
@@ -969,7 +1001,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             ip_blacklist=IPSet(),
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=WellKnownResolver(
-                self.reactor,
+                cast(ISynapseReactor, self.reactor),
                 Agent(self.reactor, contextFactory=tls_factory),
                 b"test-agent",
                 well_known_cache=self.well_known_cache,
@@ -999,7 +1031,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
             b"_matrix._tcp.testserv"
         )
 
-    def test_get_hostname_srv(self):
+    def test_get_hostname_srv(self) -> None:
         """
         Test the behaviour when there is a single SRV record
         """
@@ -1041,7 +1073,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    def test_get_well_known_srv(self):
+    def test_get_well_known_srv(self) -> None:
         """Test the behaviour when the .well-known redirects to a place where there
         is a SRV.
         """
@@ -1101,7 +1133,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    def test_idna_servername(self):
+    def test_idna_servername(self) -> None:
         """test the behaviour when the server name has idna chars in"""
         self.agent = self._make_agent()
 
@@ -1163,7 +1195,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    def test_idna_srv_target(self):
+    def test_idna_srv_target(self) -> None:
         """test the behaviour when the target of a SRV record has idna chars"""
         self.agent = self._make_agent()
 
@@ -1206,7 +1238,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)
 
-    def test_well_known_cache(self):
+    def test_well_known_cache(self) -> None:
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         fetch_d = defer.ensureDeferred(
@@ -1262,7 +1294,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         r = self.successResultOf(fetch_d)
         self.assertEqual(r.delegated_server, b"other-server")
 
-    def test_well_known_cache_with_temp_failure(self):
+    def test_well_known_cache_with_temp_failure(self) -> None:
         """Test that we refetch well-known before the cache expires, and that
         it ignores transient errors.
         """
@@ -1341,7 +1373,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         r = self.successResultOf(fetch_d)
         self.assertEqual(r.delegated_server, None)
 
-    def test_well_known_too_large(self):
+    def test_well_known_too_large(self) -> None:
         """A well-known query that returns a result which is too large should be rejected."""
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
@@ -1367,7 +1399,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         r = self.successResultOf(fetch_d)
         self.assertIsNone(r.delegated_server)
 
-    def test_srv_fallbacks(self):
+    def test_srv_fallbacks(self) -> None:
         """Test that other SRV results are tried if the first one fails."""
         self.agent = self._make_agent()
 
@@ -1427,7 +1459,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
 
 class TestCachePeriodFromHeaders(unittest.TestCase):
-    def test_cache_control(self):
+    def test_cache_control(self) -> None:
         # uppercase
         self.assertEqual(
             _cache_period_from_headers(
@@ -1464,7 +1496,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
             0,
         )
 
-    def test_expires(self):
+    def test_expires(self) -> None:
         self.assertEqual(
             _cache_period_from_headers(
                 Headers({b"Expires": [b"Wed, 30 Jan 2019 07:35:33 GMT"]}),
@@ -1491,15 +1523,15 @@ class TestCachePeriodFromHeaders(unittest.TestCase):
         self.assertEqual(_cache_period_from_headers(Headers({b"Expires": [b"0"]})), 0)
 
 
-def _check_logcontext(context):
+def _check_logcontext(context: LoggingContextOrSentinel) -> None:
     current = current_context()
     if current is not context:
         raise AssertionError("Expected logcontext %s but was %s" % (context, current))
 
 
 def _wrap_server_factory_for_tls(
-    factory: IProtocolFactory, sanlist: Iterable[bytes] = None
-) -> IProtocolFactory:
+    factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
+) -> TLSMemoryBIOFactory:
     """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
     The resultant factory will create a TLS server which presents a certificate
     signed by our test CA, valid for the domains in `sanlist`
@@ -1537,7 +1569,7 @@ def _get_test_protocol_factory() -> IProtocolFactory:
     return server_factory
 
 
-def _log_request(request: str):
+def _log_request(request: str) -> None:
     """Implements Factory.log, which is expected by Request.finish"""
     logger.info(f"Completed request {request}")
 
@@ -1547,6 +1579,8 @@ class TrustingTLSPolicyForHTTPS:
     """An IPolicyForHTTPS which checks that the certificate belongs to the
     right server, but doesn't check the certificate chain."""
 
-    def creatorForNetloc(self, hostname, port):
+    def creatorForNetloc(
+        self, hostname: bytes, port: int
+    ) -> IOpenSSLClientConnectionCreator:
         certificateOptions = OpenSSLCertificateOptions()
         return ClientTLSOptions(hostname, certificateOptions.getContext())
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index 77ce8432ac..6ab13357f9 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Dict, Generator, List, Tuple, cast
 from unittest.mock import Mock
 
 from twisted.internet import defer
@@ -20,7 +20,7 @@ from twisted.internet.defer import Deferred
 from twisted.internet.error import ConnectError
 from twisted.names import dns, error
 
-from synapse.http.federation.srv_resolver import SrvResolver
+from synapse.http.federation.srv_resolver import Server, SrvResolver
 from synapse.logging.context import LoggingContext, current_context
 
 from tests import unittest
@@ -28,7 +28,7 @@ from tests.utils import MockClock
 
 
 class SrvResolverTestCase(unittest.TestCase):
-    def test_resolve(self):
+    def test_resolve(self) -> None:
         dns_client_mock = Mock()
 
         service_name = b"test_service.example.com"
@@ -38,18 +38,18 @@ class SrvResolverTestCase(unittest.TestCase):
             type=dns.SRV, payload=dns.Record_SRV(target=host_name)
         )
 
-        result_deferred = Deferred()
+        result_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
         dns_client_mock.lookupService.return_value = result_deferred
 
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         @defer.inlineCallbacks
-        def do_lookup():
-
+        def do_lookup() -> Generator["Deferred[object]", object, List[Server]]:
             with LoggingContext("one") as ctx:
                 resolve_d = resolver.resolve_service(service_name)
-                result = yield defer.ensureDeferred(resolve_d)
+                result: List[Server]
+                result = yield defer.ensureDeferred(resolve_d)  # type: ignore[assignment]
 
                 # should have restored our context
                 self.assertIs(current_context(), ctx)
@@ -70,7 +70,9 @@ class SrvResolverTestCase(unittest.TestCase):
         self.assertEqual(servers[0].host, host_name)
 
     @defer.inlineCallbacks
-    def test_from_cache_expired_and_dns_fail(self):
+    def test_from_cache_expired_and_dns_fail(
+        self,
+    ) -> Generator["Deferred[object]", object, None]:
         dns_client_mock = Mock()
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
 
@@ -81,10 +83,13 @@ class SrvResolverTestCase(unittest.TestCase):
         entry.priority = 0
         entry.weight = 0
 
-        cache = {service_name: [entry]}
+        cache = {service_name: [cast(Server, entry)]}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
+        servers: List[Server]
+        servers = yield defer.ensureDeferred(
+            resolver.resolve_service(service_name)
+        )  # type: ignore[assignment]
 
         dns_client_mock.lookupService.assert_called_once_with(service_name)
 
@@ -92,7 +97,7 @@ class SrvResolverTestCase(unittest.TestCase):
         self.assertEqual(servers, cache[service_name])
 
     @defer.inlineCallbacks
-    def test_from_cache(self):
+    def test_from_cache(self) -> Generator["Deferred[object]", object, None]:
         clock = MockClock()
 
         dns_client_mock = Mock(spec_set=["lookupService"])
@@ -105,12 +110,15 @@ class SrvResolverTestCase(unittest.TestCase):
         entry.priority = 0
         entry.weight = 0
 
-        cache = {service_name: [entry]}
+        cache = {service_name: [cast(Server, entry)]}
         resolver = SrvResolver(
             dns_client=dns_client_mock, cache=cache, get_time=clock.time
         )
 
-        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
+        servers: List[Server]
+        servers = yield defer.ensureDeferred(
+            resolver.resolve_service(service_name)
+        )  # type: ignore[assignment]
 
         self.assertFalse(dns_client_mock.lookupService.called)
 
@@ -118,45 +126,48 @@ class SrvResolverTestCase(unittest.TestCase):
         self.assertEqual(servers, cache[service_name])
 
     @defer.inlineCallbacks
-    def test_empty_cache(self):
+    def test_empty_cache(self) -> Generator["Deferred[object]", object, None]:
         dns_client_mock = Mock()
 
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError())
 
         service_name = b"test_service.example.com"
 
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         with self.assertRaises(error.DNSServerError):
             yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
     @defer.inlineCallbacks
-    def test_name_error(self):
+    def test_name_error(self) -> Generator["Deferred[object]", object, None]:
         dns_client_mock = Mock()
 
         dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError())
 
         service_name = b"test_service.example.com"
 
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
+        servers: List[Server]
+        servers = yield defer.ensureDeferred(
+            resolver.resolve_service(service_name)
+        )  # type: ignore[assignment]
 
         self.assertEqual(len(servers), 0)
         self.assertEqual(len(cache), 0)
 
-    def test_disabled_service(self):
+    def test_disabled_service(self) -> None:
         """
         test the behaviour when there is a single record which is ".".
         """
         service_name = b"test_service.example.com"
 
-        lookup_deferred = Deferred()
+        lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
         dns_client_mock = Mock()
         dns_client_mock.lookupService.return_value = lookup_deferred
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         # Old versions of Twisted don't have an ensureDeferred in failureResultOf.
@@ -173,16 +184,16 @@ class SrvResolverTestCase(unittest.TestCase):
 
         self.failureResultOf(resolve_d, ConnectError)
 
-    def test_non_srv_answer(self):
+    def test_non_srv_answer(self) -> None:
         """
         test the behaviour when the dns server gives us a spurious non-SRV response
         """
         service_name = b"test_service.example.com"
 
-        lookup_deferred = Deferred()
+        lookup_deferred: "Deferred[Tuple[List[dns.RRHeader], None, None]]" = Deferred()
         dns_client_mock = Mock()
         dns_client_mock.lookupService.return_value = lookup_deferred
-        cache = {}
+        cache: Dict[bytes, List[Server]] = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         # Old versions of Twisted don't have an ensureDeferred in successResultOf.
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 5071f83574..36472e57a8 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -556,6 +556,6 @@ def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str:
     return method_name
 
 
-def _hash_stack(stack: List[inspect.FrameInfo]):
+def _hash_stack(stack: List[inspect.FrameInfo]) -> Tuple[str, ...]:
     """Turns a stack into a hashable value that can be put into a set."""
     return tuple(_format_stack_frame(frame) for frame in stack)
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 391196425c..ec6aacf235 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -11,28 +11,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
 
+from twisted.web.server import Request
 
 from synapse.http.additional_resource import AdditionalResource
 from synapse.http.server import respond_with_json
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from tests.server import FakeSite, make_request
 from tests.unittest import HomeserverTestCase
 
 
 class _AsyncTestCustomEndpoint:
-    def __init__(self, config, module_api):
+    def __init__(self, config: JsonDict, module_api: Any) -> None:
         pass
 
-    async def handle_request(self, request):
+    async def handle_request(self, request: Request) -> None:
+        assert isinstance(request, SynapseRequest)
         respond_with_json(request, 200, {"some_key": "some_value_async"})
 
 
 class _SyncTestCustomEndpoint:
-    def __init__(self, config, module_api):
+    def __init__(self, config: JsonDict, module_api: Any) -> None:
         pass
 
-    async def handle_request(self, request):
+    async def handle_request(self, request: Request) -> None:
+        assert isinstance(request, SynapseRequest)
         respond_with_json(request, 200, {"some_key": "some_value_sync"})
 
 
@@ -41,7 +47,7 @@ class AdditionalResourceTests(HomeserverTestCase):
     and async handlers.
     """
 
-    def test_async(self):
+    def test_async(self) -> None:
         handler = _AsyncTestCustomEndpoint({}, None).handle_request
         resource = AdditionalResource(self.hs, handler)
 
@@ -52,7 +58,7 @@ class AdditionalResourceTests(HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
 
-    def test_sync(self):
+    def test_sync(self) -> None:
         handler = _SyncTestCustomEndpoint({}, None).handle_request
         resource = AdditionalResource(self.hs, handler)
 
diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 7e2f2a01cc..57b6a84e23 100644
--- a/tests/http/test_client.py
+++ b/tests/http/test_client.py
@@ -13,10 +13,12 @@
 #  limitations under the License.
 
 from io import BytesIO
+from typing import Tuple, Union
 from unittest.mock import Mock
 
 from netaddr import IPSet
 
+from twisted.internet.defer import Deferred
 from twisted.internet.error import DNSLookupError
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol
@@ -28,6 +30,7 @@ from synapse.http.client import (
     BlacklistingAgentWrapper,
     BlacklistingReactorWrapper,
     BodyExceededMaxSize,
+    _DiscardBodyWithMaxSizeProtocol,
     read_body_with_max_size,
 )
 
@@ -36,7 +39,9 @@ from tests.unittest import TestCase
 
 
 class ReadBodyWithMaxSizeTests(TestCase):
-    def _build_response(self, length=UNKNOWN_LENGTH):
+    def _build_response(
+        self, length: Union[int, str] = UNKNOWN_LENGTH
+    ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
         """Start reading the body, returns the response, result and proto"""
         response = Mock(length=length)
         result = BytesIO()
@@ -48,23 +53,27 @@ class ReadBodyWithMaxSizeTests(TestCase):
 
         return result, deferred, protocol
 
-    def _assert_error(self, deferred, protocol):
+    def _assert_error(
+        self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol
+    ) -> None:
         """Ensure that the expected error is received."""
-        self.assertIsInstance(deferred.result, Failure)
+        assert isinstance(deferred.result, Failure)
         self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
-        protocol.transport.abortConnection.assert_called_once()
+        assert protocol.transport is not None
+        # type-ignore: presumably abortConnection has been replaced with a Mock.
+        protocol.transport.abortConnection.assert_called_once()  # type: ignore[attr-defined]
 
-    def _cleanup_error(self, deferred):
+    def _cleanup_error(self, deferred: "Deferred[int]") -> None:
         """Ensure that the error in the Deferred is handled gracefully."""
         called = [False]
 
-        def errback(f):
+        def errback(f: Failure) -> None:
             called[0] = True
 
         deferred.addErrback(errback)
         self.assertTrue(called[0])
 
-    def test_no_error(self):
+    def test_no_error(self) -> None:
         """A response that is NOT too large."""
         result, deferred, protocol = self._build_response()
 
@@ -76,7 +85,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self.assertEqual(result.getvalue(), b"12345")
         self.assertEqual(deferred.result, 5)
 
-    def test_too_large(self):
+    def test_too_large(self) -> None:
         """A response which is too large raises an exception."""
         result, deferred, protocol = self._build_response()
 
@@ -87,7 +96,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self._assert_error(deferred, protocol)
         self._cleanup_error(deferred)
 
-    def test_multiple_packets(self):
+    def test_multiple_packets(self) -> None:
         """Data should be accumulated through mutliple packets."""
         result, deferred, protocol = self._build_response()
 
@@ -100,7 +109,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self.assertEqual(result.getvalue(), b"1234")
         self.assertEqual(deferred.result, 4)
 
-    def test_additional_data(self):
+    def test_additional_data(self) -> None:
         """A connection can receive data after being closed."""
         result, deferred, protocol = self._build_response()
 
@@ -115,7 +124,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
         self._assert_error(deferred, protocol)
         self._cleanup_error(deferred)
 
-    def test_content_length(self):
+    def test_content_length(self) -> None:
         """The body shouldn't be read (at all) if the Content-Length header is too large."""
         result, deferred, protocol = self._build_response(length=10)
 
@@ -132,7 +141,7 @@ class ReadBodyWithMaxSizeTests(TestCase):
 
 
 class BlacklistingAgentTest(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.reactor, self.clock = get_clock()
 
         self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
@@ -140,7 +149,7 @@ class BlacklistingAgentTest(TestCase):
         self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
 
         # Configure the reactor's DNS resolver.
-        for (domain, ip) in (
+        for domain, ip in (
             (self.safe_domain, self.safe_ip),
             (self.unsafe_domain, self.unsafe_ip),
             (self.allowed_domain, self.allowed_ip),
@@ -151,7 +160,7 @@ class BlacklistingAgentTest(TestCase):
         self.ip_whitelist = IPSet([self.allowed_ip.decode()])
         self.ip_blacklist = IPSet(["5.0.0.0/8"])
 
-    def test_reactor(self):
+    def test_reactor(self) -> None:
         """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
         agent = Agent(
             BlacklistingReactorWrapper(
@@ -197,12 +206,12 @@ class BlacklistingAgentTest(TestCase):
             response = self.successResultOf(d)
             self.assertEqual(response.code, 200)
 
-    def test_agent(self):
+    def test_agent(self) -> None:
         """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
         agent = BlacklistingAgentWrapper(
             Agent(self.reactor),
-            ip_whitelist=self.ip_whitelist,
             ip_blacklist=self.ip_blacklist,
+            ip_whitelist=self.ip_whitelist,
         )
 
         # The unsafe IPs should be rejected.
diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py
index a801f002a0..8c18e56881 100644
--- a/tests/http/test_endpoint.py
+++ b/tests/http/test_endpoint.py
@@ -17,7 +17,7 @@ from tests import unittest
 
 
 class ServerNameTestCase(unittest.TestCase):
-    def test_parse_server_name(self):
+    def test_parse_server_name(self) -> None:
         test_data = {
             "localhost": ("localhost", None),
             "my-example.com:1234": ("my-example.com", 1234),
@@ -32,7 +32,7 @@ class ServerNameTestCase(unittest.TestCase):
         for i, o in test_data.items():
             self.assertEqual(parse_server_name(i), o)
 
-    def test_validate_bad_server_names(self):
+    def test_validate_bad_server_names(self) -> None:
         test_data = [
             "",  # empty
             "localhost:http",  # non-numeric port
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index 9b64890bdb..9373523dea 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -11,16 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from typing import Generator
 from unittest.mock import Mock
 
 from netaddr import IPSet
 from parameterized import parameterized
 
 from twisted.internet import defer
-from twisted.internet.defer import TimeoutError
+from twisted.internet.defer import Deferred, TimeoutError
 from twisted.internet.error import ConnectingCancelledError, DNSLookupError
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
 from twisted.web.client import ResponseNeverReceived
 from twisted.web.http import HTTPChannel
 
@@ -30,34 +30,43 @@ from synapse.http.matrixfederationclient import (
     MatrixFederationHttpClient,
     MatrixFederationRequest,
 )
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import (
+    SENTINEL_CONTEXT,
+    LoggingContext,
+    LoggingContextOrSentinel,
+    current_context,
+)
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.server import FakeTransport
 from tests.unittest import HomeserverTestCase
 
 
-def check_logcontext(context):
+def check_logcontext(context: LoggingContextOrSentinel) -> None:
     current = current_context()
     if current is not context:
         raise AssertionError("Expected logcontext %s but was %s" % (context, current))
 
 
 class FederationClientTests(HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
         return hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.cl = MatrixFederationHttpClient(self.hs, None)
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
-    def test_client_get(self):
+    def test_client_get(self) -> None:
         """
         happy-path test of a GET request
         """
 
         @defer.inlineCallbacks
-        def do_request():
+        def do_request() -> Generator["Deferred[object]", object, object]:
             with LoggingContext("one") as context:
                 fetch_d = defer.ensureDeferred(
                     self.cl.get_json("testserv:8008", "foo/bar")
@@ -119,7 +128,7 @@ class FederationClientTests(HomeserverTestCase):
         # check the response is as expected
         self.assertEqual(res, {"a": 1})
 
-    def test_dns_error(self):
+    def test_dns_error(self) -> None:
         """
         If the DNS lookup returns an error, it will bubble up.
         """
@@ -132,7 +141,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value, RequestSendFailed)
         self.assertIsInstance(f.value.inner_exception, DNSLookupError)
 
-    def test_client_connection_refused(self):
+    def test_client_connection_refused(self) -> None:
         d = defer.ensureDeferred(
             self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
         )
@@ -156,7 +165,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value, RequestSendFailed)
         self.assertIs(f.value.inner_exception, e)
 
-    def test_client_never_connect(self):
+    def test_client_never_connect(self) -> None:
         """
         If the HTTP request is not connected and is timed out, it'll give a
         ConnectingCancelledError or TimeoutError.
@@ -188,7 +197,7 @@ class FederationClientTests(HomeserverTestCase):
             f.value.inner_exception, (ConnectingCancelledError, TimeoutError)
         )
 
-    def test_client_connect_no_response(self):
+    def test_client_connect_no_response(self) -> None:
         """
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
@@ -222,7 +231,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value, RequestSendFailed)
         self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)
 
-    def test_client_ip_range_blacklist(self):
+    def test_client_ip_range_blacklist(self) -> None:
         """Ensure that Synapse does not try to connect to blacklisted IPs"""
 
         # Set up the ip_range blacklist
@@ -292,7 +301,7 @@ class FederationClientTests(HomeserverTestCase):
         f = self.failureResultOf(d, RequestSendFailed)
         self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)
 
-    def test_client_gets_headers(self):
+    def test_client_gets_headers(self) -> None:
         """
         Once the client gets the headers, _request returns successfully.
         """
@@ -319,7 +328,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertEqual(r.code, 200)
 
     @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
-    def test_timeout_reading_body(self, method_name: str):
+    def test_timeout_reading_body(self, method_name: str) -> None:
         """
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a RequestSendFailed with can_retry.
@@ -351,7 +360,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertTrue(f.value.can_retry)
         self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)
 
-    def test_client_requires_trailing_slashes(self):
+    def test_client_requires_trailing_slashes(self) -> None:
         """
         If a connection is made to a client but the client rejects it due to
         requiring a trailing slash. We need to retry the request with a
@@ -405,7 +414,7 @@ class FederationClientTests(HomeserverTestCase):
         r = self.successResultOf(d)
         self.assertEqual(r, {})
 
-    def test_client_does_not_retry_on_400_plus(self):
+    def test_client_does_not_retry_on_400_plus(self) -> None:
         """
         Another test for trailing slashes but now test that we don't retry on
         trailing slashes on a non-400/M_UNRECOGNIZED response.
@@ -450,7 +459,7 @@ class FederationClientTests(HomeserverTestCase):
         # We should get a 404 failure response
         self.failureResultOf(d)
 
-    def test_client_sends_body(self):
+    def test_client_sends_body(self) -> None:
         defer.ensureDeferred(  # type: ignore[unused-awaitable]
             self.cl.post_json(
                 "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
@@ -474,7 +483,7 @@ class FederationClientTests(HomeserverTestCase):
         content = request.content.read()
         self.assertEqual(content, b'{"a":"b"}')
 
-    def test_closes_connection(self):
+    def test_closes_connection(self) -> None:
         """Check that the client closes unused HTTP connections"""
         d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
 
@@ -514,7 +523,7 @@ class FederationClientTests(HomeserverTestCase):
         self.assertTrue(conn.disconnecting)
 
     @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
-    def test_json_error(self, return_value):
+    def test_json_error(self, return_value: bytes) -> None:
         """
         Test what happens if invalid JSON is returned from the remote endpoint.
         """
@@ -560,7 +569,7 @@ class FederationClientTests(HomeserverTestCase):
         f = self.failureResultOf(test_d)
         self.assertIsInstance(f.value, RequestSendFailed)
 
-    def test_too_big(self):
+    def test_too_big(self) -> None:
         """
         Test what happens if a huge response is returned from the remote endpoint.
         """
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 2db77c6a73..cc175052ac 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -14,7 +14,7 @@
 import base64
 import logging
 import os
-from typing import Iterable, Optional
+from typing import List, Optional
 from unittest.mock import patch
 
 import treq
@@ -22,9 +22,13 @@ from netaddr import IPSet
 from parameterized import parameterized
 
 from twisted.internet import interfaces  # noqa: F401
-from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint
+from twisted.internet.endpoints import (
+    HostnameEndpoint,
+    _WrapperEndpoint,
+    _WrappingProtocol,
+)
 from twisted.internet.interfaces import IProtocol, IProtocolFactory
-from twisted.internet.protocol import Factory
+from twisted.internet.protocol import Factory, Protocol
 from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
 from twisted.web.http import HTTPChannel
 
@@ -32,9 +36,14 @@ from synapse.http.client import BlacklistingReactorWrapper
 from synapse.http.connectproxyclient import ProxyCredentials
 from synapse.http.proxyagent import ProxyAgent, parse_proxy
 
-from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.http import (
+    TestServerTLSConnectionFactory,
+    dummy_address,
+    get_test_https_policy,
+)
 from tests.server import FakeTransport, ThreadedMemoryReactorClock
 from tests.unittest import TestCase
+from tests.utils import checked_cast
 
 logger = logging.getLogger(__name__)
 
@@ -183,7 +192,7 @@ class ProxyParserTests(TestCase):
         expected_hostname: bytes,
         expected_port: int,
         expected_credentials: Optional[bytes],
-    ):
+    ) -> None:
         """
         Tests that a given proxy URL will be broken into the components.
         Args:
@@ -209,7 +218,7 @@ class ProxyParserTests(TestCase):
 
 
 class MatrixFederationAgentTests(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.reactor = ThreadedMemoryReactorClock()
 
     def _make_connection(
@@ -218,7 +227,7 @@ class MatrixFederationAgentTests(TestCase):
         server_factory: IProtocolFactory,
         ssl: bool = False,
         expected_sni: Optional[bytes] = None,
-        tls_sanlist: Optional[Iterable[bytes]] = None,
+        tls_sanlist: Optional[List[bytes]] = None,
     ) -> IProtocol:
         """Builds a test server, and completes the outgoing client connection
 
@@ -244,7 +253,8 @@ class MatrixFederationAgentTests(TestCase):
         if ssl:
             server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
 
-        server_protocol = server_factory.buildProtocol(None)
+        server_protocol = server_factory.buildProtocol(dummy_address)
+        assert server_protocol is not None
 
         # now, tell the client protocol factory to build the client protocol,
         # and wire the output of said protocol up to the server via
@@ -252,7 +262,8 @@ class MatrixFederationAgentTests(TestCase):
         #
         # Normally this would be done by the TCP socket code in Twisted, but we are
         # stubbing that out here.
-        client_protocol = client_factory.buildProtocol(None)
+        client_protocol = client_factory.buildProtocol(dummy_address)
+        assert client_protocol is not None
         client_protocol.makeConnection(
             FakeTransport(server_protocol, self.reactor, client_protocol)
         )
@@ -263,6 +274,7 @@ class MatrixFederationAgentTests(TestCase):
         )
 
         if ssl:
+            assert isinstance(server_protocol, TLSMemoryBIOProtocol)
             http_protocol = server_protocol.wrappedProtocol
             tls_connection = server_protocol._tlsConnection
         else:
@@ -288,7 +300,7 @@ class MatrixFederationAgentTests(TestCase):
         scheme: bytes,
         hostname: bytes,
         path: bytes,
-    ):
+    ) -> None:
         """Runs a test case for a direct connection not going through a proxy.
 
         Args:
@@ -319,6 +331,7 @@ class MatrixFederationAgentTests(TestCase):
             ssl=is_https,
             expected_sni=hostname if is_https else None,
         )
+        assert isinstance(http_server, HTTPChannel)
 
         # the FakeTransport is async, so we need to pump the reactor
         self.reactor.advance(0)
@@ -339,34 +352,34 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")
 
-    def test_http_request(self):
+    def test_http_request(self) -> None:
         agent = ProxyAgent(self.reactor)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
-    def test_https_request(self):
+    def test_https_request(self) -> None:
         agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
         self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
 
-    def test_http_request_use_proxy_empty_environment(self):
+    def test_http_request_use_proxy_empty_environment(self) -> None:
         agent = ProxyAgent(self.reactor, use_proxy=True)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
-    def test_http_request_via_uppercase_no_proxy(self):
+    def test_http_request_via_uppercase_no_proxy(self) -> None:
         agent = ProxyAgent(self.reactor, use_proxy=True)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
     @patch.dict(
         os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
     )
-    def test_http_request_via_no_proxy(self):
+    def test_http_request_via_no_proxy(self) -> None:
         agent = ProxyAgent(self.reactor, use_proxy=True)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
     @patch.dict(
         os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
     )
-    def test_https_request_via_no_proxy(self):
+    def test_https_request_via_no_proxy(self) -> None:
         agent = ProxyAgent(
             self.reactor,
             contextFactory=get_test_https_policy(),
@@ -375,12 +388,12 @@ class MatrixFederationAgentTests(TestCase):
         self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
-    def test_http_request_via_no_proxy_star(self):
+    def test_http_request_via_no_proxy_star(self) -> None:
         agent = ProxyAgent(self.reactor, use_proxy=True)
         self._test_request_direct_connection(agent, b"http", b"test.com", b"")
 
     @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
-    def test_https_request_via_no_proxy_star(self):
+    def test_https_request_via_no_proxy_star(self) -> None:
         agent = ProxyAgent(
             self.reactor,
             contextFactory=get_test_https_policy(),
@@ -389,7 +402,7 @@ class MatrixFederationAgentTests(TestCase):
         self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
-    def test_http_request_via_proxy(self):
+    def test_http_request_via_proxy(self) -> None:
         """
         Tests that requests can be made through a proxy.
         """
@@ -401,7 +414,7 @@ class MatrixFederationAgentTests(TestCase):
         os.environ,
         {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"},
     )
-    def test_http_request_via_proxy_with_auth(self):
+    def test_http_request_via_proxy_with_auth(self) -> None:
         """
         Tests that authenticated requests can be made through a proxy.
         """
@@ -412,7 +425,7 @@ class MatrixFederationAgentTests(TestCase):
     @patch.dict(
         os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"}
     )
-    def test_http_request_via_https_proxy(self):
+    def test_http_request_via_https_proxy(self) -> None:
         self._do_http_request_via_proxy(
             expect_proxy_ssl=True, expected_auth_credentials=None
         )
@@ -424,13 +437,13 @@ class MatrixFederationAgentTests(TestCase):
             "no_proxy": "unused.com",
         },
     )
-    def test_http_request_via_https_proxy_with_auth(self):
+    def test_http_request_via_https_proxy_with_auth(self) -> None:
         self._do_http_request_via_proxy(
             expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
         )
 
     @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
-    def test_https_request_via_proxy(self):
+    def test_https_request_via_proxy(self) -> None:
         """Tests that TLS-encrypted requests can be made through a proxy"""
         self._do_https_request_via_proxy(
             expect_proxy_ssl=False, expected_auth_credentials=None
@@ -440,7 +453,7 @@ class MatrixFederationAgentTests(TestCase):
         os.environ,
         {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
     )
-    def test_https_request_via_proxy_with_auth(self):
+    def test_https_request_via_proxy_with_auth(self) -> None:
         """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
         self._do_https_request_via_proxy(
             expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies"
@@ -449,7 +462,7 @@ class MatrixFederationAgentTests(TestCase):
     @patch.dict(
         os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"}
     )
-    def test_https_request_via_https_proxy(self):
+    def test_https_request_via_https_proxy(self) -> None:
         """Tests that TLS-encrypted requests can be made through a proxy"""
         self._do_https_request_via_proxy(
             expect_proxy_ssl=True, expected_auth_credentials=None
@@ -459,7 +472,7 @@ class MatrixFederationAgentTests(TestCase):
         os.environ,
         {"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
     )
-    def test_https_request_via_https_proxy_with_auth(self):
+    def test_https_request_via_https_proxy_with_auth(self) -> None:
         """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
         self._do_https_request_via_proxy(
             expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies"
@@ -469,7 +482,7 @@ class MatrixFederationAgentTests(TestCase):
         self,
         expect_proxy_ssl: bool = False,
         expected_auth_credentials: Optional[bytes] = None,
-    ):
+    ) -> None:
         """Send a http request via an agent and check that it is correctly received at
             the proxy. The proxy can use either http or https.
         Args:
@@ -501,6 +514,7 @@ class MatrixFederationAgentTests(TestCase):
             tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None,
             expected_sni=b"proxy.com" if expect_proxy_ssl else None,
         )
+        assert isinstance(http_server, HTTPChannel)
 
         # the FakeTransport is async, so we need to pump the reactor
         self.reactor.advance(0)
@@ -542,7 +556,7 @@ class MatrixFederationAgentTests(TestCase):
         self,
         expect_proxy_ssl: bool = False,
         expected_auth_credentials: Optional[bytes] = None,
-    ):
+    ) -> None:
         """Send a https request via an agent and check that it is correctly received at
             the proxy and client. The proxy can use either http or https.
         Args:
@@ -606,10 +620,11 @@ class MatrixFederationAgentTests(TestCase):
         # now we make another test server to act as the upstream HTTP server.
         server_ssl_protocol = _wrap_server_factory_for_tls(
             _get_test_protocol_factory()
-        ).buildProtocol(None)
+        ).buildProtocol(dummy_address)
 
         # Tell the HTTP server to send outgoing traffic back via the proxy's transport.
         proxy_server_transport = proxy_server.transport
+        assert proxy_server_transport is not None
         server_ssl_protocol.makeConnection(proxy_server_transport)
 
         # ... and replace the protocol on the proxy's transport with the
@@ -629,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
         else:
             assert isinstance(proxy_server_transport, FakeTransport)
             client_protocol = proxy_server_transport.other
-            c2s_transport = client_protocol.transport
+            assert isinstance(client_protocol, Protocol)
+            c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
             c2s_transport.other = server_ssl_protocol
 
         self.reactor.advance(0)
@@ -644,6 +660,7 @@ class MatrixFederationAgentTests(TestCase):
 
         # now there should be a pending request
         http_server = server_ssl_protocol.wrappedProtocol
+        assert isinstance(http_server, HTTPChannel)
         self.assertEqual(len(http_server.requests), 1)
 
         request = http_server.requests[0]
@@ -667,7 +684,7 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(body, b"result")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
-    def test_http_request_via_proxy_with_blacklist(self):
+    def test_http_request_via_proxy_with_blacklist(self) -> None:
         # The blacklist includes the configured proxy IP.
         agent = ProxyAgent(
             BlacklistingReactorWrapper(
@@ -691,6 +708,7 @@ class MatrixFederationAgentTests(TestCase):
         http_server = self._make_connection(
             client_factory, _get_test_protocol_factory()
         )
+        assert isinstance(http_server, HTTPChannel)
 
         # the FakeTransport is async, so we need to pump the reactor
         self.reactor.advance(0)
@@ -712,7 +730,7 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(body, b"result")
 
     @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
-    def test_https_request_via_uppercase_proxy_with_blacklist(self):
+    def test_https_request_via_uppercase_proxy_with_blacklist(self) -> None:
         # The blacklist includes the configured proxy IP.
         agent = ProxyAgent(
             BlacklistingReactorWrapper(
@@ -737,11 +755,17 @@ class MatrixFederationAgentTests(TestCase):
         proxy_server = self._make_connection(
             client_factory, _get_test_protocol_factory()
         )
+        assert isinstance(proxy_server, HTTPChannel)
 
         # fish the transports back out so that we can do the old switcheroo
-        s2c_transport = proxy_server.transport
-        client_protocol = s2c_transport.other
-        c2s_transport = client_protocol.transport
+        # To help mypy out with the various Protocols and wrappers and mocks, we do
+        # some explicit casting. Without the casts, we hit the bug I reported at
+        # https://github.com/Shoobx/mypy-zope/issues/91 .
+        # We also double-checked these casts at runtime (test-time) because I found it
+        # quite confusing to deduce these types in the first place!
+        s2c_transport = checked_cast(FakeTransport, proxy_server.transport)
+        client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other)
+        c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
 
         # the FakeTransport is async, so we need to pump the reactor
         self.reactor.advance(0)
@@ -762,8 +786,10 @@ class MatrixFederationAgentTests(TestCase):
 
         # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
         ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
-        ssl_protocol = ssl_factory.buildProtocol(None)
+        ssl_protocol = ssl_factory.buildProtocol(dummy_address)
+        assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
         http_server = ssl_protocol.wrappedProtocol
+        assert isinstance(http_server, HTTPChannel)
 
         ssl_protocol.makeConnection(
             FakeTransport(client_protocol, self.reactor, ssl_protocol)
@@ -797,39 +823,35 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(body, b"result")
 
     @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
-    def test_proxy_with_no_scheme(self):
+    def test_proxy_with_no_scheme(self) -> None:
         http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
-        self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
-        self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
-        self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+        proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+        self.assertEqual(proxy_ep._hostStr, "proxy.com")
+        self.assertEqual(proxy_ep._port, 8888)
 
     @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
-    def test_proxy_with_unsupported_scheme(self):
+    def test_proxy_with_unsupported_scheme(self) -> None:
         with self.assertRaises(ValueError):
             ProxyAgent(self.reactor, use_proxy=True)
 
     @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
-    def test_proxy_with_http_scheme(self):
+    def test_proxy_with_http_scheme(self) -> None:
         http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
-        self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
-        self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
-        self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
+        proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
+        self.assertEqual(proxy_ep._hostStr, "proxy.com")
+        self.assertEqual(proxy_ep._port, 8888)
 
     @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
-    def test_proxy_with_https_scheme(self):
+    def test_proxy_with_https_scheme(self) -> None:
         https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
-        self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
-        self.assertEqual(
-            https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
-        )
-        self.assertEqual(
-            https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888
-        )
+        proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint)
+        self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com")
+        self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
 
 
 def _wrap_server_factory_for_tls(
-    factory: IProtocolFactory, sanlist: Iterable[bytes] = None
-) -> IProtocolFactory:
+    factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
+) -> TLSMemoryBIOFactory:
     """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
 
     The resultant factory will create a TLS server which presents a certificate
@@ -865,6 +887,6 @@ def _get_test_protocol_factory() -> IProtocolFactory:
     return server_factory
 
 
-def _log_request(request: str):
+def _log_request(request: str) -> None:
     """Implements Factory.log, which is expected by Request.finish"""
     logger.info(f"Completed request {request}")
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index 46166292fe..c8d215b6dc 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -14,7 +14,7 @@
 import json
 from http import HTTPStatus
 from io import BytesIO
-from typing import Tuple
+from typing import Tuple, Union
 from unittest.mock import Mock
 
 from synapse.api.errors import Codes, SynapseError
@@ -33,7 +33,7 @@ from tests import unittest
 from tests.http.server._base import test_disconnect
 
 
-def make_request(content):
+def make_request(content: Union[bytes, JsonDict]) -> Mock:
     """Make an object that acts enough like a request."""
     request = Mock(spec=["method", "uri", "content"])
 
@@ -47,7 +47,7 @@ def make_request(content):
 
 
 class TestServletUtils(unittest.TestCase):
-    def test_parse_json_value(self):
+    def test_parse_json_value(self) -> None:
         """Basic tests for parse_json_value_from_request."""
         # Test round-tripping.
         obj = {"foo": 1}
@@ -78,7 +78,7 @@ class TestServletUtils(unittest.TestCase):
         with self.assertRaises(SynapseError):
             parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
 
-    def test_parse_json_object(self):
+    def test_parse_json_object(self) -> None:
         """Basic tests for parse_json_object_from_request."""
         # Test empty.
         result = parse_json_object_from_request(
diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
index c85a3665c1..010601da4b 100644
--- a/tests/http/test_simple_client.py
+++ b/tests/http/test_simple_client.py
@@ -17,22 +17,24 @@ from netaddr import IPSet
 
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
 from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
 
 class SimpleHttpClientTests(HomeserverTestCase):
-    def prepare(self, reactor, clock, hs: "HomeServer"):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: "HomeServer") -> None:
         # Add a DNS entry for a test server
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         self.cl = hs.get_simple_http_client()
 
-    def test_dns_error(self):
+    def test_dns_error(self) -> None:
         """
         If the DNS lookup returns an error, it will bubble up.
         """
@@ -42,7 +44,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
         f = self.failureResultOf(d)
         self.assertIsInstance(f.value, DNSLookupError)
 
-    def test_client_connection_refused(self):
+    def test_client_connection_refused(self) -> None:
         d = defer.ensureDeferred(self.cl.get_json("http://testserv:8008/foo/bar"))
 
         self.pump()
@@ -63,7 +65,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
 
         self.assertIs(f.value, e)
 
-    def test_client_never_connect(self):
+    def test_client_never_connect(self) -> None:
         """
         If the HTTP request is not connected and is timed out, it'll give a
         ConnectingCancelledError or TimeoutError.
@@ -90,7 +92,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
 
         self.assertIsInstance(f.value, RequestTimedOutError)
 
-    def test_client_connect_no_response(self):
+    def test_client_connect_no_response(self) -> None:
         """
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
@@ -121,7 +123,7 @@ class SimpleHttpClientTests(HomeserverTestCase):
 
         self.assertIsInstance(f.value, RequestTimedOutError)
 
-    def test_client_ip_range_blacklist(self):
+    def test_client_ip_range_blacklist(self) -> None:
         """Ensure that Synapse does not try to connect to blacklisted IPs"""
 
         # Add some DNS entries we'll blacklist
diff --git a/tests/http/test_site.py b/tests/http/test_site.py
index b2dbf76d33..9a78fede92 100644
--- a/tests/http/test_site.py
+++ b/tests/http/test_site.py
@@ -13,18 +13,20 @@
 # limitations under the License.
 
 from twisted.internet.address import IPv6Address
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
 
 from synapse.app.homeserver import SynapseHomeServer
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
 
 class SynapseRequestTestCase(HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
 
-    def test_large_request(self):
+    def test_large_request(self) -> None:
         """overlarge HTTP requests should be rejected"""
         self.hs.start_listening()
 
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index c08954d887..5191e31a8a 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -21,6 +21,7 @@ from synapse.logging import RemoteHandler
 from tests.logging import LoggerCleanupMixin
 from tests.server import FakeTransport, get_clock
 from tests.unittest import TestCase
+from tests.utils import checked_cast
 
 
 def connect_logging_client(
@@ -56,8 +57,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         client, server = connect_logging_client(self.reactor, 0)
 
         # Trigger data being sent
-        assert isinstance(client.transport, FakeTransport)
-        client.transport.flush()
+        client_transport = checked_cast(FakeTransport, client.transport)
+        client_transport.flush()
 
         # One log message, with a single trailing newline
         logs = server.data.decode("utf8").splitlines()
@@ -89,8 +90,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
 
         # Allow the reconnection
         client, server = connect_logging_client(self.reactor, 0)
-        assert isinstance(client.transport, FakeTransport)
-        client.transport.flush()
+        client_transport = checked_cast(FakeTransport, client.transport)
+        client_transport.flush()
 
         # Only the 7 infos made it through, the debugs were elided
         logs = server.data.splitlines()
@@ -123,8 +124,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
 
         # Allow the reconnection
         client, server = connect_logging_client(self.reactor, 0)
-        assert isinstance(client.transport, FakeTransport)
-        client.transport.flush()
+        client_transport = checked_cast(FakeTransport, client.transport)
+        client_transport.flush()
 
         # The 10 warnings made it through, the debugs and infos were elided
         logs = server.data.splitlines()
@@ -148,8 +149,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
 
         # Allow the reconnection
         client, server = connect_logging_client(self.reactor, 0)
-        assert isinstance(client.transport, FakeTransport)
-        client.transport.flush()
+        client_transport = checked_cast(FakeTransport, client.transport)
+        client_transport.flush()
 
         # The first five and last five warnings made it through, the debugs and
         # infos were elided
diff --git a/tests/rest/media/v1/__init__.py b/tests/media/__init__.py
index b1ee10cfcc..68910cbf5b 100644
--- a/tests/rest/media/v1/__init__.py
+++ b/tests/media/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2018 New Vector Ltd
+#  Copyright 2023 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tests/rest/media/v1/test_base.py b/tests/media/test_base.py
index c73179151a..66498c744d 100644
--- a/tests/rest/media/v1/test_base.py
+++ b/tests/media/test_base.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.media._base import get_filename_from_headers
 
 from tests import unittest
 
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/media/test_filepath.py
index 43e6f0f70a..95e3b83d5a 100644
--- a/tests/rest/media/v1/test_filepath.py
+++ b/tests/media/test_filepath.py
@@ -15,7 +15,7 @@ import inspect
 import os
 from typing import Iterable
 
-from synapse.rest.media.v1.filepath import MediaFilePaths, _wrap_with_jail_check
+from synapse.media.filepath import MediaFilePaths, _wrap_with_jail_check
 
 from tests import unittest
 
diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/media/test_html_preview.py
index 1062081a06..e7da75db3e 100644
--- a/tests/rest/media/v1/test_html_preview.py
+++ b/tests/media/test_html_preview.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.rest.media.v1.preview_html import (
+from synapse.media.preview_html import (
     _get_html_media_encodings,
     decode_body,
     parse_html_to_open_graph,
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/media/test_media_storage.py
index d18fc13c21..870047d0f2 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
 import tempfile
 from binascii import unhexlify
 from io import BytesIO
-from typing import Any, BinaryIO, Dict, List, Optional, Union
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
 from unittest.mock import Mock
 from urllib import parse
 
@@ -32,16 +32,17 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.errors import Codes
 from synapse.events import EventBase
 from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable
+from synapse.media._base import FileInfo
+from synapse.media.filepath import MediaFilePaths
+from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
+from synapse.media.storage_provider import FileStorageProviderBackend
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login
-from synapse.rest.media.v1._base import FileInfo
-from synapse.rest.media.v1.filepath import MediaFilePaths
-from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
-from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 from synapse.server import HomeServer
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
 from synapse.util import Clock
 
 from tests import unittest
@@ -51,7 +52,6 @@ from tests.utils import default_config
 
 
 class MediaStorageTests(unittest.HomeserverTestCase):
-
     needs_threadpool = True
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -201,36 +201,45 @@ class _TestImage:
     ],
 )
 class MediaRepoTests(unittest.HomeserverTestCase):
-
+    test_image: ClassVar[_TestImage]
     hijack_auth = True
     user_id = "@test:user"
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
-        self.fetches = []
+        self.fetches: List[
+            Tuple[
+                "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]",
+                str,
+                str,
+                Optional[QueryParams],
+            ]
+        ] = []
 
         def get_file(
             destination: str,
             path: str,
             output_stream: BinaryIO,
-            args: Optional[Dict[str, Union[str, List[str]]]] = None,
+            args: Optional[QueryParams] = None,
+            retry_on_dns_fail: bool = True,
             max_size: Optional[int] = None,
-        ) -> Deferred:
-            """
-            Returns tuple[int,dict,str,int] of file length, response headers,
-            absolute URI, and response code.
-            """
+            ignore_backoff: bool = False,
+        ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+            """A mock for MatrixFederationHttpClient.get_file."""
 
-            def write_to(r):
+            def write_to(
+                r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+            ) -> Tuple[int, Dict[bytes, List[bytes]]]:
                 data, response = r
                 output_stream.write(data)
                 return response
 
-            d = Deferred()
-            d.addCallback(write_to)
+            d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
             self.fetches.append((d, destination, path, args))
-            return make_deferred_yieldable(d)
+            # Note that this callback changes the value held by d.
+            d_after_callback = d.addCallback(write_to)
+            return make_deferred_yieldable(d_after_callback)
 
+        # Mock out the homeserver's MatrixFederationHttpClient
         client = Mock()
         client.get_file = get_file
 
@@ -244,7 +253,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         config["max_image_pixels"] = 2000000
 
         provider_config = {
-            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
             "store_local": True,
             "store_synchronous": False,
             "store_remote": True,
@@ -257,7 +266,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         media_resource = hs.get_media_repository_resource()
         self.download_resource = media_resource.children[b"download"]
         self.thumbnail_resource = media_resource.children[b"thumbnail"]
@@ -461,6 +469,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         # Synapse should regenerate missing thumbnails.
         origin, media_id = self.media_id.split("/")
         info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
+        assert info is not None
         file_id = info["filesystem_id"]
 
         thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
@@ -581,7 +590,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                         "thumbnail_method": method,
                         "thumbnail_type": self.test_image.content_type,
                         "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail1{self.test_image.extension}",
+                        "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
                     },
                     {
                         "thumbnail_width": 32,
@@ -589,10 +598,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                         "thumbnail_method": method,
                         "thumbnail_type": self.test_image.content_type,
                         "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail2{self.test_image.extension}",
+                        "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
                     },
                 ],
-                file_id=f"image{self.test_image.extension}",
+                file_id=f"image{self.test_image.extension.decode()}",
                 url_cache=None,
                 server_name=None,
             )
@@ -637,6 +646,7 @@ class TestSpamCheckerLegacy:
         self.config = config
         self.api = api
 
+    @staticmethod
     def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
         return config
 
@@ -748,7 +758,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
     async def check_media_file_for_spam(
         self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
-    ) -> Union[Codes, Literal["NOT_SPAM"]]:
+    ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
         buf = BytesIO()
         await file_wrapper.write_chunks_to(buf.write)
 
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/media/test_oembed.py
index 3f7f1dbab9..c8bf8421da 100644
--- a/tests/rest/media/v1/test_oembed.py
+++ b/tests/media/test_oembed.py
@@ -18,7 +18,7 @@ from parameterized import parameterized
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
+from synapse.media.oembed import OEmbedProvider, OEmbedResult
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8f88c0117d..3a1929691e 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EduTypes, EventTypes
 from synapse.api.errors import NotFoundError
@@ -21,9 +23,12 @@ from synapse.events import EventBase
 from synapse.federation.units import Transaction
 from synapse.handlers.presence import UserPresenceState
 from synapse.handlers.push_rules import InvalidRuleException
+from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login, notifications, presence, profile, room
-from synapse.types import create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
+from synapse.util import Clock
 
 from tests.events.test_presence_router import send_presence_update, sync_presence
 from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import HomeserverTestCase, override_config
 
 
-class ModuleApiTestCase(HomeserverTestCase):
+class BaseModuleApiTestCase(HomeserverTestCase):
+    """Common properties of the two test case classes."""
+
+    module_api: ModuleApi
+
+    # These are all written by _test_sending_local_online_presence_to_local_user.
+    presence_receiver_id: str
+    presence_receiver_tok: str
+    presence_sender_id: str
+    presence_sender_tok: str
+
+
+class ModuleApiTestCase(BaseModuleApiTestCase):
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -42,23 +59,23 @@ class ModuleApiTestCase(HomeserverTestCase):
         notifications.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastores().main
-        self.module_api = homeserver.get_module_api()
-        self.event_creation_handler = homeserver.get_event_creation_handler()
-        self.sync_handler = homeserver.get_sync_handler()
-        self.auth_handler = homeserver.get_auth_handler()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.module_api = hs.get_module_api()
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.sync_handler = hs.get_sync_handler()
+        self.auth_handler = hs.get_auth_handler()
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
-        fed_transport_client = Mock(spec=["send_transaction"])
-        fed_transport_client.send_transaction = simple_async_mock({})
+        self.fed_transport_client = Mock(spec=["send_transaction"])
+        self.fed_transport_client.send_transaction = simple_async_mock({})
 
         return self.setup_test_homeserver(
-            federation_transport_client=fed_transport_client,
+            federation_transport_client=self.fed_transport_client,
         )
 
-    def test_can_register_user(self):
+    def test_can_register_user(self) -> None:
         """Tests that an external module can register a user"""
         # Register a new user
         user_id, access_token = self.get_success(
@@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase):
         displayname = self.get_success(self.store.get_profile_displayname("bob"))
         self.assertEqual(displayname, "Bobberino")
 
-    def test_can_register_admin_user(self):
+    def test_can_register_admin_user(self) -> None:
         user_id = self.register_user(
             "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
         )
 
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        assert found_user is not None
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, True)
 
-    def test_can_set_admin(self):
+    def test_can_set_admin(self) -> None:
         user_id = self.register_user(
             "alice_wants_admin",
             "1234",
@@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         self.get_success(self.module_api.set_user_admin(user_id, True))
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        assert found_user is not None
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, True)
 
-    def test_can_set_displayname(self):
+    def test_can_set_displayname(self) -> None:
         localpart = "alice_wants_a_new_displayname"
         user_id = self.register_user(
             localpart, "1234", displayname="Alice", admin=False
         )
         found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
-
+        assert found_userinfo is not None
         self.get_success(
             self.module_api.set_displayname(
                 found_userinfo.user_id, "Bob", deactivation=False
@@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         self.assertEqual(found_profile.display_name, "Bob")
 
-    def test_get_userinfo_by_id(self):
+    def test_get_userinfo_by_id(self) -> None:
         user_id = self.register_user("alice", "1234")
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        assert found_user is not None
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, False)
 
-    def test_get_userinfo_by_id__no_user_found(self):
+    def test_get_userinfo_by_id__no_user_found(self) -> None:
         found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
         self.assertIsNone(found_user)
 
-    def test_get_user_ip_and_agents(self):
+    def test_get_user_ip_and_agents(self) -> None:
         user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
 
         # Initially, we should have no ip/agent for our user.
@@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         # we should only find the second ip, agent.
         info = self.get_success(
             self.module_api.get_user_ip_and_agents(
-                user_id, (last_seen_1 + last_seen_2) / 2
+                user_id, (last_seen_1 + last_seen_2) // 2
             )
         )
         self.assertEqual(len(info), 1)
@@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         self.assertEqual(info, [])
 
-    def test_get_user_ip_and_agents__no_user_found(self):
+    def test_get_user_ip_and_agents__no_user_found(self) -> None:
         info = self.get_success(
             self.module_api.get_user_ip_and_agents(
                 "@test_get_user_ip_and_agents_user_nonexistent:example.com"
@@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         self.assertEqual(info, [])
 
-    def test_sending_events_into_room(self):
+    def test_sending_events_into_room(self) -> None:
         """Tests that a module can send events into a room"""
         # Mock out create_and_send_nonmember_event to check whether events are being sent
-        self.event_creation_handler.create_and_send_nonmember_event = Mock(
+        self.event_creation_handler.create_and_send_nonmember_event = Mock(  # type: ignore[assignment]
             spec=[],
             side_effect=self.event_creation_handler.create_and_send_nonmember_event,
         )
@@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         room_id = self.helper.create_room_as(user_id, tok=tok)
 
         # Create and send a non-state event
-        content = {"body": "I am a puppet", "msgtype": "m.text"}
+        content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"}
         event_dict = {
             "room_id": room_id,
             "type": "m.room.message",
@@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             "sender": user_id,
             "state_key": "",
         }
-        event: EventBase = self.get_success(
+        event = self.get_success(
             self.module_api.create_and_send_event_into_room(event_dict)
         )
         self.assertEqual(event.sender, user_id)
@@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             self.module_api.create_and_send_event_into_room(event_dict), Exception
         )
 
-    def test_public_rooms(self):
+    def test_public_rooms(self) -> None:
         """Tests that a room can be added and removed from the public rooms list,
         as well as have its public rooms directory state queried.
         """
@@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         self.assertFalse(is_in_public_rooms)
 
-    def test_send_local_online_presence_to(self):
+    def test_send_local_online_presence_to(self) -> None:
         # Test sending local online presence to users from the main process
         _test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
 
     # Enable federation sending on the main process.
     @override_config({"federation_sender_instances": None})
-    def test_send_local_online_presence_to_federation(self):
+    def test_send_local_online_presence_to_federation(self) -> None:
         """Tests that send_local_presence_to_users sends local online presence to remote users."""
         # Create a user who will send presence updates
         self.presence_sender_id = self.register_user("presence_sender1", "monkey")
@@ -397,7 +417,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         #
         # Thus we reset the mock, and try sending online local user
         # presence again
-        self.hs.get_federation_transport_client().send_transaction.reset_mock()
+        self.fed_transport_client.send_transaction.reset_mock()
 
         # Broadcast local user online presence
         self.get_success(
@@ -409,9 +429,7 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         # Check that a presence update was sent as part of a federation transaction
         found_update = False
-        calls = (
-            self.hs.get_federation_transport_client().send_transaction.call_args_list
-        )
+        calls = self.fed_transport_client.send_transaction.call_args_list
         for call in calls:
             call_args = call[0]
             federation_transaction: Transaction = call_args[0]
@@ -431,7 +449,7 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         self.assertTrue(found_update)
 
-    def test_update_membership(self):
+    def test_update_membership(self) -> None:
         """Tests that the module API can update the membership of a user in a room."""
         peter = self.register_user("peter", "hackme")
         lesley = self.register_user("lesley", "hackme")
@@ -554,14 +572,14 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.assertEqual(res["displayname"], "simone")
         self.assertIsNone(res["avatar_url"])
 
-    def test_update_room_membership_remote_join(self):
+    def test_update_room_membership_remote_join(self) -> None:
         """Test that the module API can join a remote room."""
         # Necessary to fake a remote join.
         fake_stream_id = 1
         mocked_remote_join = simple_async_mock(
             return_value=("fake-event-id", fake_stream_id)
         )
-        self.hs.get_room_member_handler()._remote_join = mocked_remote_join
+        self.hs.get_room_member_handler()._remote_join = mocked_remote_join  # type: ignore[assignment]
         fake_remote_host = f"{self.module_api.server_name}-remote"
 
         # Given that the join is to be faked, we expect the relevant join event not to
@@ -582,7 +600,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         # Check that a remote join was attempted.
         self.assertEqual(mocked_remote_join.call_count, 1)
 
-    def test_get_room_state(self):
+    def test_get_room_state(self) -> None:
         """Tests that a module can retrieve the state of a room through the module API."""
         user_id = self.register_user("peter", "hackme")
         tok = self.login("peter", "hackme")
@@ -677,7 +695,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             self.module_api.check_push_rule_actions(["foo"])
 
         with self.assertRaises(InvalidRuleException):
-            self.module_api.check_push_rule_actions({"foo": "bar"})
+            self.module_api.check_push_rule_actions([{"foo": "bar"}])
 
         self.module_api.check_push_rule_actions(["notify"])
 
@@ -756,7 +774,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.assertIsNone(room_alias)
 
 
-class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
+class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase):
     """For testing ModuleApi functionality in a multi-worker setup"""
 
     servlets = [
@@ -766,7 +784,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
         presence.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         conf = super().default_config()
         conf["stream_writers"] = {"presence": ["presence_writer"]}
         conf["instance_map"] = {
@@ -774,18 +792,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
         }
         return conf
 
-    def prepare(self, reactor, clock, homeserver):
-        self.module_api = homeserver.get_module_api()
-        self.sync_handler = homeserver.get_sync_handler()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.module_api = hs.get_module_api()
+        self.sync_handler = hs.get_sync_handler()
 
-    def test_send_local_online_presence_to_workers(self):
+    def test_send_local_online_presence_to_workers(self) -> None:
         # Test sending local online presence to users from a worker process
         _test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
 
 
 def _test_sending_local_online_presence_to_local_user(
-    test_case: HomeserverTestCase, test_with_workers: bool = False
-):
+    test_case: BaseModuleApiTestCase, test_with_workers: bool = False
+) -> None:
     """Tests that send_local_presence_to_users sends local online presence to local users.
 
     This simultaneously tests two different usecases:
@@ -852,6 +870,7 @@ def _test_sending_local_online_presence_to_local_user(
         # Replicate the current sync presence token from the main process to the worker process.
         # We need to do this so that the worker process knows the current presence stream ID to
         # insert into the database when we call ModuleApi.send_local_online_presence_to.
+        assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
         test_case.replicate()
 
     # Syncing again should result in no presence updates
@@ -868,6 +887,7 @@ def _test_sending_local_online_presence_to_local_user(
 
     # Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
     if test_with_workers:
+        assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
         module_api_to_use = worker_hs.get_module_api()
     else:
         module_api_to_use = test_case.module_api
@@ -875,12 +895,11 @@ def _test_sending_local_online_presence_to_local_user(
     # Trigger sending local online presence. We expect this information
     # to be saved to the database where all processes can access it.
     # Note that we're syncing via the master.
-    d = module_api_to_use.send_local_online_presence_to(
-        [
-            test_case.presence_receiver_id,
-        ]
+    d = defer.ensureDeferred(
+        module_api_to_use.send_local_online_presence_to(
+            [test_case.presence_receiver_id],
+        )
     )
-    d = defer.ensureDeferred(d)
 
     if test_with_workers:
         # In order for the required presence_set_state replication request to occur between the
@@ -897,7 +916,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update: UserPresenceState = presence_updates[0]
+    presence_update = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -908,7 +927,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update: UserPresenceState = presence_updates[0]
+    presence_update = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -936,12 +955,13 @@ def _test_sending_local_online_presence_to_local_user(
     test_case.assertEqual(len(presence_updates), 1)
 
     # Now trigger sending local online presence.
-    d = module_api_to_use.send_local_online_presence_to(
-        [
-            test_case.presence_receiver_id,
-        ]
+    d = defer.ensureDeferred(
+        module_api_to_use.send_local_online_presence_to(
+            [
+                test_case.presence_receiver_id,
+            ]
+        )
     )
-    d = defer.ensureDeferred(d)
 
     if test_with_workers:
         # In order for the required presence_set_state replication request to occur between the
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 7567756135..46df0102f7 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -33,7 +33,6 @@ from tests.unittest import HomeserverTestCase, override_config
 
 
 class TestBulkPushRuleEvaluator(HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -131,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
 
         # Create a new message event, and try to evaluate it under the dodgy
         # power level event.
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_event(
                 self.requester,
                 {
@@ -146,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 prev_event_ids=[pl_event_id],
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
         # should not raise
@@ -171,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         """Ensure that push rules are not calculated when disabled in the config"""
 
         # Create a new message event which should cause a notification.
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_event(
                 self.requester,
                 {
@@ -185,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 },
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
         # Mock the method which calculates push rules -- we do this instead of
@@ -201,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
     ) -> bool:
         """Returns true iff the `mentions` trigger an event push action."""
         # Create a new message event which should cause a notification.
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_event(
                 self.requester,
                 {
@@ -212,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 },
             )
         )
-
+        context = self.get_success(unpersisted_context.persist(event))
         # Execute the push rule machinery.
         self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
 
@@ -377,7 +378,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
 
         # Create & persist an event to use as the parent of the relation.
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_event(
                 self.requester,
                 {
@@ -391,6 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
                 },
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             self.event_creation_handler.handle_new_client_event(
                 self.requester, events_and_context=[(event, context)]
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index ab8bb417e7..4ea5472eb4 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.errors import Codes, SynapseError
+from synapse.push.emailpusher import EmailPusher
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.util import Clock
@@ -38,7 +39,6 @@ class _User:
 
 
 class EmailPusherTests(HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -47,7 +47,6 @@ class EmailPusherTests(HomeserverTestCase):
     hijack_auth = False
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
         config["email"] = {
             "enable_notifs": True,
@@ -105,6 +104,7 @@ class EmailPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
         )
+        assert user_tuple is not None
         self.token_id = user_tuple.token_id
 
         # We need to add email to account before we can create a pusher.
@@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
             )
         )
 
-        self.pusher = self.get_success(
+        pusher = self.get_success(
             self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=self.user_id,
                 access_token=self.token_id,
@@ -127,6 +127,8 @@ class EmailPusherTests(HomeserverTestCase):
                 data={},
             )
         )
+        assert isinstance(pusher, EmailPusher)
+        self.pusher = pusher
 
         self.auth_handler = hs.get_auth_handler()
         self.store = hs.get_datastores().main
@@ -367,18 +369,19 @@ class EmailPusherTests(HomeserverTestCase):
 
         # disassociate the user's email address
         self.get_success(
-            self.auth_handler.delete_threepid(
-                user_id=self.user_id,
-                medium="email",
-                address="a@example.com",
+            self.auth_handler.delete_local_threepid(
+                user_id=self.user_id, medium="email", address="a@example.com"
             )
         )
 
         # check that the pusher for that email address has been deleted
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 0)
 
     def test_remove_unlinked_pushers_background_job(self) -> None:
@@ -413,10 +416,13 @@ class EmailPusherTests(HomeserverTestCase):
         self.wait_for_background_updates()
 
         # Check that all pushers with unlinked addresses were deleted
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 0)
 
     def _check_for_mail(self) -> Tuple[Sequence, Dict]:
@@ -428,10 +434,13 @@ class EmailPusherTests(HomeserverTestCase):
             that notification.
         """
         # Get the stream ordering before it gets sent
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         last_stream_ordering = pushers[0].last_stream_ordering
 
@@ -439,10 +448,13 @@ class EmailPusherTests(HomeserverTestCase):
         self.pump(10)
 
         # It hasn't succeeded yet, so the stream ordering shouldn't have moved
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
@@ -458,10 +470,13 @@ class EmailPusherTests(HomeserverTestCase):
         self.assertEqual(len(self.email_attempts), 1)
 
         # The stream ordering has increased
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by(
+                    {"user_name": self.user_id}
+                )
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
 
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 23447cc310..c280ddcdf6 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Tuple
 from unittest.mock import Mock
 
 from twisted.internet.defer import Deferred
@@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.push import PusherConfig, PusherConfigException
 from synapse.rest.client import login, push_rule, pusher, receipts, room
 from synapse.server import HomeServer
-from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import JsonDict
 from synapse.util import Clock
 
@@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
-        def test_data(data: Optional[JsonDict]) -> None:
+        def test_data(data: Any) -> None:
             self.get_failure(
                 self.hs.get_pusherpool().add_or_update_pusher(
                     user_id=user_id,
@@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase):
         self.helper.send(room, body="There!", tok=other_access_token)
 
         # Get the stream ordering before it gets sent
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         last_stream_ordering = pushers[0].last_stream_ordering
 
@@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase):
         self.pump()
 
         # It hasn't succeeded yet, so the stream ordering shouldn't have moved
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
@@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase):
         self.pump()
 
         # The stream ordering has increased
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
         last_stream_ordering = pushers[0].last_stream_ordering
@@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase):
         self.pump()
 
         # The stream ordering has increased, again
-        pushers = self.get_success(
-            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        pushers = list(
+            self.get_success(
+                self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+            )
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
 
@@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
         device_id = user_tuple.device_id
 
@@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase):
         )
 
         # Look up the user info for the access token so we can compare the device ID.
-        lookup_result: TokenLookupResult = self.get_success(
+        lookup_result = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert lookup_result is not None
 
         # Get the user's devices and check it has the correct device ID.
         channel = self.make_request("GET", "/pushers", access_token=access_token)
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 7c430c4ecb..52c4aafea6 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Set, Union, cast
+from typing import Any, Dict, List, Optional, Union, cast
 
 import frozendict
 
@@ -22,7 +22,7 @@ import synapse.rest.admin
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.appservice import ApplicationService
-from synapse.events import FrozenEvent
+from synapse.events import FrozenEvent, make_event_from_dict
 from synapse.push.bulk_push_rule_evaluator import _flatten_dict
 from synapse.push.httppusher import tweaks_for_actions
 from synapse.rest import admin
@@ -32,6 +32,7 @@ from synapse.storage.databases.main.appservice import _make_exclusive_regex
 from synapse.synapse_rust.push import PushRuleEvaluator
 from synapse.types import JsonDict, JsonMapping, UserID
 from synapse.util import Clock
+from synapse.util.frozenutils import freeze
 
 from tests import unittest
 from tests.test_utils.event_injection import create_event, inject_member_event
@@ -48,17 +49,93 @@ class FlattenDictTestCase(unittest.TestCase):
         input = {"foo": {"bar": "abc"}}
         self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input))
 
+        # If a field has a dot in it, escape it.
+        input = {"m.foo": {"b\\ar": "abc"}}
+        self.assertEqual({"m\\.foo.b\\\\ar": "abc"}, _flatten_dict(input))
+
     def test_non_string(self) -> None:
-        """Non-string items are dropped."""
+        """String, booleans, ints, nulls and list of those should be kept while other items are dropped."""
         input: Dict[str, Any] = {
             "woo": "woo",
             "foo": True,
             "bar": 1,
             "baz": None,
-            "fuzz": [],
+            "fuzz": ["woo", True, 1, None, [], {}],
             "boo": {},
         }
-        self.assertEqual({"woo": "woo"}, _flatten_dict(input))
+        self.assertEqual(
+            {
+                "woo": "woo",
+                "foo": True,
+                "bar": 1,
+                "baz": None,
+                "fuzz": ["woo", True, 1, None],
+            },
+            _flatten_dict(input),
+        )
+
+    def test_event(self) -> None:
+        """Events can also be flattened."""
+        event = make_event_from_dict(
+            {
+                "room_id": "!test:test",
+                "type": "m.room.message",
+                "sender": "@alice:test",
+                "content": {
+                    "msgtype": "m.text",
+                    "body": "Hello world!",
+                    "format": "org.matrix.custom.html",
+                    "formatted_body": "<h1>Hello world!</h1>",
+                },
+            },
+            room_version=RoomVersions.V8,
+        )
+        expected = {
+            "content.msgtype": "m.text",
+            "content.body": "Hello world!",
+            "content.format": "org.matrix.custom.html",
+            "content.formatted_body": "<h1>Hello world!</h1>",
+            "room_id": "!test:test",
+            "sender": "@alice:test",
+            "type": "m.room.message",
+        }
+        self.assertEqual(expected, _flatten_dict(event))
+
+    def test_extensible_events(self) -> None:
+        """Extensible events has compatibility behaviour."""
+        event_dict = {
+            "room_id": "!test:test",
+            "type": "m.room.message",
+            "sender": "@alice:test",
+            "content": {
+                "org.matrix.msc1767.markup": [
+                    {"mimetype": "text/plain", "body": "Hello world!"},
+                    {"mimetype": "text/html", "body": "<h1>Hello world!</h1>"},
+                ]
+            },
+        }
+
+        # For a current room version, there's no special behavior.
+        event = make_event_from_dict(event_dict, room_version=RoomVersions.V8)
+        expected = {
+            "room_id": "!test:test",
+            "sender": "@alice:test",
+            "type": "m.room.message",
+            "content.org\\.matrix\\.msc1767\\.markup": [],
+        }
+        self.assertEqual(expected, _flatten_dict(event))
+
+        # For a room version with extensible events, they parse out the text/plain
+        # to a content.body property.
+        event = make_event_from_dict(event_dict, room_version=RoomVersions.MSC1767v10)
+        expected = {
+            "content.body": "hello world!",
+            "room_id": "!test:test",
+            "sender": "@alice:test",
+            "type": "m.room.message",
+            "content.org\\.matrix\\.msc1767\\.markup": [],
+        }
+        self.assertEqual(expected, _flatten_dict(event))
 
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
@@ -66,9 +143,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         self,
         content: JsonMapping,
         *,
-        has_mentions: bool = False,
-        user_mentions: Optional[Set[str]] = None,
-        room_mention: bool = False,
         related_events: Optional[JsonDict] = None,
     ) -> PushRuleEvaluator:
         event = FrozenEvent(
@@ -87,9 +161,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
         return PushRuleEvaluator(
             _flatten_dict(event),
-            has_mentions,
-            user_mentions or set(),
-            room_mention,
+            False,
             room_member_count,
             sender_power_level,
             cast(Dict[str, int], power_levels.get("notifications", {})),
@@ -123,53 +195,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         # A display name with spaces should work fine.
         self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
 
-    def test_user_mentions(self) -> None:
-        """Check for user mentions."""
-        condition = {"kind": "org.matrix.msc3952.is_user_mention"}
-
-        # No mentions shouldn't match.
-        evaluator = self._get_evaluator({}, has_mentions=True)
-        self.assertFalse(evaluator.matches(condition, "@user:test", None))
-
-        # An empty set shouldn't match
-        evaluator = self._get_evaluator({}, has_mentions=True, user_mentions=set())
-        self.assertFalse(evaluator.matches(condition, "@user:test", None))
-
-        # The Matrix ID appearing anywhere in the mentions list should match
-        evaluator = self._get_evaluator(
-            {}, has_mentions=True, user_mentions={"@user:test"}
-        )
-        self.assertTrue(evaluator.matches(condition, "@user:test", None))
-
-        evaluator = self._get_evaluator(
-            {}, has_mentions=True, user_mentions={"@another:test", "@user:test"}
-        )
-        self.assertTrue(evaluator.matches(condition, "@user:test", None))
-
-        # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
-        # since the BulkPushRuleEvaluator is what handles data sanitisation.
-
-    def test_room_mentions(self) -> None:
-        """Check for room mentions."""
-        condition = {"kind": "org.matrix.msc3952.is_room_mention"}
-
-        # No room mention shouldn't match.
-        evaluator = self._get_evaluator({}, has_mentions=True)
-        self.assertFalse(evaluator.matches(condition, None, None))
-
-        # Room mention should match.
-        evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True)
-        self.assertTrue(evaluator.matches(condition, None, None))
-
-        # A room mention and user mention is valid.
-        evaluator = self._get_evaluator(
-            {}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True
-        )
-        self.assertTrue(evaluator.matches(condition, None, None))
-
-        # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
-        # since the BulkPushRuleEvaluator is what handles data sanitisation.
-
     def _assert_matches(
         self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
     ) -> None:
@@ -341,6 +366,193 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             "pattern should not match before a newline",
         )
 
+    def test_event_match_pattern(self) -> None:
+        """Check that event_match conditions do not use a "pattern_type" from user data."""
+
+        # The pattern_type should not be deserialized into anything valid.
+        condition = {
+            "kind": "event_match",
+            "key": "content.value",
+            "pattern_type": "user_id",
+        }
+        self._assert_not_matches(
+            condition,
+            {"value": "@user:test"},
+            "should not be possible to pass a pattern_type in",
+        )
+
+        # This is an internal-only condition which shouldn't get deserialized.
+        condition = {
+            "kind": "event_match_type",
+            "key": "content.value",
+            "pattern_type": "user_id",
+        }
+        self._assert_not_matches(
+            condition,
+            {"value": "@user:test"},
+            "should not be possible to pass a pattern_type in",
+        )
+
+    def test_exact_event_match_string(self) -> None:
+        """Check that exact_event_match conditions work as expected for strings."""
+
+        # Test against a string value.
+        condition = {
+            "kind": "event_property_is",
+            "key": "content.value",
+            "value": "foobaz",
+        }
+        self._assert_matches(
+            condition,
+            {"value": "foobaz"},
+            "exact value should match",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "FoobaZ"},
+            "values should match and be case-sensitive",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "test foobaz test"},
+            "values must exactly match",
+        )
+        value: Any
+        for value in (True, False, 1, 1.1, None, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
+        # it should work on frozendicts too
+        self._assert_matches(
+            condition,
+            frozendict.frozendict({"value": "foobaz"}),
+            "values should match on frozendicts",
+        )
+
+    def test_exact_event_match_boolean(self) -> None:
+        """Check that exact_event_match conditions work as expected for booleans."""
+
+        # Test against a True boolean value.
+        condition = {"kind": "event_property_is", "key": "content.value", "value": True}
+        self._assert_matches(
+            condition,
+            {"value": True},
+            "exact value should match",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": False},
+            "incorrect values should not match",
+        )
+        for value in ("foobaz", 1, 1.1, None, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
+        # Test against a False boolean value.
+        condition = {
+            "kind": "event_property_is",
+            "key": "content.value",
+            "value": False,
+        }
+        self._assert_matches(
+            condition,
+            {"value": False},
+            "exact value should match",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": True},
+            "incorrect values should not match",
+        )
+        # Choose false-y values to ensure there's no type coercion.
+        for value in ("", 0, 1.1, None, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
+    def test_exact_event_match_null(self) -> None:
+        """Check that exact_event_match conditions work as expected for null."""
+
+        condition = {"kind": "event_property_is", "key": "content.value", "value": None}
+        self._assert_matches(
+            condition,
+            {"value": None},
+            "exact value should match",
+        )
+        for value in ("foobaz", True, False, 1, 1.1, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
+    def test_exact_event_match_integer(self) -> None:
+        """Check that exact_event_match conditions work as expected for integers."""
+
+        condition = {"kind": "event_property_is", "key": "content.value", "value": 1}
+        self._assert_matches(
+            condition,
+            {"value": 1},
+            "exact value should match",
+        )
+        value: Any
+        for value in (1.1, -1, 0):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect values should not match",
+            )
+        for value in ("1", True, False, None, [], {}):
+            self._assert_not_matches(
+                condition,
+                {"value": value},
+                "incorrect types should not match",
+            )
+
+    def test_exact_event_property_contains(self) -> None:
+        """Check that exact_event_property_contains conditions work as expected."""
+
+        condition = {
+            "kind": "event_property_contains",
+            "key": "content.value",
+            "value": "foobaz",
+        }
+        self._assert_matches(
+            condition,
+            {"value": ["foobaz"]},
+            "exact value should match",
+        )
+        self._assert_matches(
+            condition,
+            {"value": ["foobaz", "bugz"]},
+            "extra values should match",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": ["FoobaZ"]},
+            "values should match and be case-sensitive",
+        )
+        self._assert_not_matches(
+            condition,
+            {"value": "foobaz"},
+            "does not search in a string",
+        )
+
+        # it should work on frozendicts too
+        self._assert_matches(
+            condition,
+            freeze({"value": ["foobaz"]}),
+            "values should match on frozendicts",
+        )
+
     def test_no_body(self) -> None:
         """Not having a body shouldn't break the evaluator."""
         evaluator = self._get_evaluator({})
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 6a7174b333..46a8e2013e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -16,7 +16,9 @@ from collections import defaultdict
 from typing import Any, Dict, List, Optional, Set, Tuple
 
 from twisted.internet.address import IPv4Address
-from twisted.internet.protocol import Protocol
+from twisted.internet.protocol import Protocol, connectionDone
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import GenericWorkerServer
@@ -30,6 +32,7 @@ from synapse.replication.tcp.protocol import (
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -51,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     if not hiredis:
         skip = "Requires hiredis"
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
@@ -92,8 +95,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
             repl_handler,
         )
 
-        self._client_transport = None
-        self._server_transport = None
+        self._client_transport: Optional[FakeTransport] = None
+        self._server_transport: Optional[FakeTransport] = None
 
     def create_resource_dict(self) -> Dict[str, Resource]:
         d = super().create_resource_dict()
@@ -107,10 +110,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         config["worker_replication_http_port"] = "8765"
         return config
 
-    def _build_replication_data_handler(self):
+    def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
         return TestReplicationDataHandler(self.worker_hs)
 
-    def reconnect(self):
+    def reconnect(self) -> None:
         if self._client_transport:
             self.client.close()
 
@@ -123,7 +126,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._server_transport = FakeTransport(self.client, self.reactor)
         self.server.makeConnection(self._server_transport)
 
-    def disconnect(self):
+    def disconnect(self) -> None:
         if self._client_transport:
             self._client_transport = None
             self.client.close()
@@ -132,7 +135,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
             self._server_transport = None
             self.server.close()
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
@@ -168,7 +171,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         requests: List[SynapseRequest] = []
         real_request_factory = channel.requestFactory
 
-        def request_factory(*args, **kwargs):
+        def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
             request = real_request_factory(*args, **kwargs)
             requests.append(request)
             return request
@@ -202,7 +205,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
     def assert_request_is_get_repl_stream_updates(
         self, request: SynapseRequest, stream_name: str
-    ):
+    ) -> None:
         """Asserts that the given request is a HTTP replication request for
         fetching updates for given stream.
         """
@@ -244,7 +247,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         base["redis"] = {"enabled": True}
         return base
 
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
 
         # build a replication server
@@ -287,7 +290,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             lambda: self._handle_http_replication_attempt(self.hs, 8765),
         )
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> ReplicationRestResource:
         """Overrides `HomeserverTestCase.create_test_resource`."""
         # We override this so that it automatically registers all the HTTP
         # replication servlets, without having to explicitly do that in all
@@ -301,7 +304,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         return resource
 
     def make_worker_hs(
-        self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
+        self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
     ) -> HomeServer:
         """Make a new worker HS instance, correctly connecting replcation
         stream to the master HS.
@@ -385,14 +388,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         config["worker_replication_http_port"] = "8765"
         return config
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
         self.streamer.on_notifier_poke()
         self.pump()
 
-    def _handle_http_replication_attempt(self, hs, repl_port):
+    def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
         """Handles a connection attempt to the given HS replication HTTP
         listener on the given port.
         """
@@ -429,7 +432,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # inside `connecTCP` before the connection has been passed back to the
         # code that requested the TCP connection.
 
-    def connect_any_redis_attempts(self):
+    def connect_any_redis_attempts(self) -> None:
         """If redis is enabled we need to deal with workers connecting to a
         redis server. We don't want to use a real Redis server so we use a
         fake one.
@@ -440,8 +443,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             self.assertEqual(host, "localhost")
             self.assertEqual(port, 6379)
 
-            client_protocol = client_factory.buildProtocol(None)
-            server_protocol = self._redis_server.buildProtocol(None)
+            client_address = IPv4Address("TCP", "127.0.0.1", 6379)
+            client_protocol = client_factory.buildProtocol(client_address)
+
+            server_address = IPv4Address("TCP", host, port)
+            server_protocol = self._redis_server.buildProtocol(server_address)
 
             client_to_server_transport = FakeTransport(
                 server_protocol, self.reactor, client_protocol
@@ -463,7 +469,9 @@ class TestReplicationDataHandler(ReplicationDataHandler):
         # list of received (stream_name, token, row) tuples
         self.received_rdata_rows: List[Tuple[str, int, Any]] = []
 
-    async def on_rdata(self, stream_name, instance_name, token, rows):
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ) -> None:
         await super().on_rdata(stream_name, instance_name, token, rows)
         for r in rows:
             self.received_rdata_rows.append((stream_name, token, r))
@@ -472,28 +480,30 @@ class TestReplicationDataHandler(ReplicationDataHandler):
 class FakeRedisPubSubServer:
     """A fake Redis server for pub/sub."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._subscribers_by_channel: Dict[
             bytes, Set["FakeRedisPubSubProtocol"]
         ] = defaultdict(set)
 
-    def add_subscriber(self, conn, channel: bytes):
+    def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
         """A connection has called SUBSCRIBE"""
         self._subscribers_by_channel[channel].add(conn)
 
-    def remove_subscriber(self, conn):
+    def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
         """A connection has lost connection"""
         for subscribers in self._subscribers_by_channel.values():
             subscribers.discard(conn)
 
-    def publish(self, conn, channel: bytes, msg) -> int:
+    def publish(
+        self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
+    ) -> int:
         """A connection want to publish a message to subscribers."""
         for sub in self._subscribers_by_channel[channel]:
             sub.send(["message", channel, msg])
 
         return len(self._subscribers_by_channel)
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
         return FakeRedisPubSubProtocol(self)
 
 
@@ -506,7 +516,7 @@ class FakeRedisPubSubProtocol(Protocol):
         self._server = server
         self._reader = hiredis.Reader()
 
-    def dataReceived(self, data):
+    def dataReceived(self, data: bytes) -> None:
         self._reader.feed(data)
 
         # We might get multiple messages in one packet.
@@ -523,7 +533,7 @@ class FakeRedisPubSubProtocol(Protocol):
 
             self.handle_command(msg[0], *msg[1:])
 
-    def handle_command(self, command, *args):
+    def handle_command(self, command: bytes, *args: bytes) -> None:
         """Received a Redis command from the client."""
 
         # We currently only support pub/sub.
@@ -548,9 +558,9 @@ class FakeRedisPubSubProtocol(Protocol):
             self.send("PONG")
 
         else:
-            raise Exception(f"Unknown command: {command}")
+            raise Exception(f"Unknown command: {command!r}")
 
-    def send(self, msg):
+    def send(self, msg: object) -> None:
         """Send a message back to the client."""
         assert self.transport is not None
 
@@ -559,7 +569,7 @@ class FakeRedisPubSubProtocol(Protocol):
         self.transport.write(raw)
         self.transport.flush()
 
-    def encode(self, obj):
+    def encode(self, obj: object) -> str:
         """Encode an object to its Redis format.
 
         Supports: strings/bytes, integers and list/tuples.
@@ -581,5 +591,5 @@ class FakeRedisPubSubProtocol(Protocol):
 
         raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
 
-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         self._server.remove_subscriber(self)
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index e03d9b4cc0..9be11ab802 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -74,7 +74,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
 class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
     """Tests for `ReplicationEndpoint` cancellation."""
 
-    def create_test_resource(self):
+    def create_test_resource(self) -> JsonResource:
         """Overrides `HomeserverTestCase.create_test_resource`."""
         resource = JsonResource(self.hs)
 
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index c5705256e6..4c9b494344 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -13,35 +13,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Iterable, Optional
 from unittest.mock import Mock
 
-from tests.replication._base import BaseStreamTestCase
+from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.server import HomeServer
+from synapse.util import Clock
 
-class BaseSlavedStoreTestCase(BaseStreamTestCase):
-    def make_homeserver(self, reactor, clock):
+from tests.replication._base import BaseStreamTestCase
 
-        hs = self.setup_test_homeserver(federation_client=Mock())
 
-        return hs
+class BaseSlavedStoreTestCase(BaseStreamTestCase):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        return self.setup_test_homeserver(federation_client=Mock())
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
 
         self.reconnect()
 
         self.master_store = hs.get_datastores().main
         self.slaved_store = self.worker_hs.get_datastores().main
-        self._storage_controllers = hs.get_storage_controllers()
+        persistence = hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self.persistance = persistence
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
         self.streamer.on_notifier_poke()
         self.pump(0.1)
 
-    def check(self, method, args, expected_result=None):
+    def check(
+        self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
+    ) -> None:
         master_result = self.get_success(getattr(self.master_store, method)(*args))
         slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
         if expected_result is not None:
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index dce71f7334..57c781a0c3 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Iterable, Optional
+from typing import Any, Callable, Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import ReceiptTypes
 from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
+from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.handlers.room import RoomEventSource
+from synapse.server import HomeServer
 from synapse.storage.databases.main.event_push_actions import (
     NotifCounts,
     RoomNotifCounts,
@@ -28,6 +32,7 @@ from synapse.storage.databases.main.event_push_actions import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
 from synapse.types import PersistedEventPosition
+from synapse.util import Clock
 
 from tests.server import FakeTransport
 
@@ -41,35 +46,34 @@ ROOM_ID = "!room:test"
 logger = logging.getLogger(__name__)
 
 
-def dict_equals(self, other):
+def dict_equals(self: EventBase, other: EventBase) -> bool:
     me = encode_canonical_json(self.get_pdu_json())
     them = encode_canonical_json(other.get_pdu_json())
     return me == them
 
 
-def patch__eq__(cls):
+def patch__eq__(cls: object) -> Callable[[], None]:
     eq = getattr(cls, "__eq__", None)
-    cls.__eq__ = dict_equals
+    cls.__eq__ = dict_equals  # type: ignore[assignment]
 
-    def unpatch():
+    def unpatch() -> None:
         if eq is not None:
-            cls.__eq__ = eq
+            cls.__eq__ = eq  # type: ignore[assignment]
 
     return unpatch
 
 
 class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
-
     STORE_TYPE = EventsWorkerStore
 
-    def setUp(self):
+    def setUp(self) -> None:
         # Patch up the equality operator for events so that we can check
         # whether lists of events match using assertEqual
-        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
-        return super().setUp()
+        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
+        super().setUp()
 
-    def prepare(self, *args, **kwargs):
-        super().prepare(*args, **kwargs)
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        super().prepare(reactor, clock, hs)
 
         self.get_success(
             self.master_store.store_room(
@@ -80,10 +84,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             )
         )
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         [unpatch() for unpatch in self.unpatches]
 
-    def test_get_latest_event_ids_in_room(self):
+    def test_get_latest_event_ids_in_room(self) -> None:
         create = self.persist(type="m.room.create", key="", creator=USER_ID)
         self.replicate()
         self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
@@ -97,7 +101,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         self.replicate()
         self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
 
-    def test_redactions(self):
+    def test_redactions(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.persist(type="m.room.member", key=USER_ID, membership="join")
 
@@ -117,7 +121,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         )
         self.check("get_event", [msg.event_id], redacted)
 
-    def test_backfilled_redactions(self):
+    def test_backfilled_redactions(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.persist(type="m.room.member", key=USER_ID, membership="join")
 
@@ -139,7 +143,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         )
         self.check("get_event", [msg.event_id], redacted)
 
-    def test_invites(self):
+    def test_invites(self) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
         event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
@@ -163,7 +167,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
         )
 
     @parameterized.expand([(True,), (False,)])
-    def test_push_actions_for_user(self, send_receipt: bool):
+    def test_push_actions_for_user(self, send_receipt: bool) -> None:
         self.persist(type="m.room.create", key="", creator=USER_ID)
         self.persist(type="m.room.member", key=USER_ID, membership="join")
         self.persist(
@@ -219,7 +223,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             ),
         )
 
-    def test_get_rooms_for_user_with_stream_ordering(self):
+    def test_get_rooms_for_user_with_stream_ordering(self) -> None:
         """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
         by rows in the events stream
         """
@@ -243,7 +247,9 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
         )
 
-    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
+    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
+        self,
+    ) -> None:
         """Check that current_state invalidation happens correctly with multiple events
         in the persistence batch.
 
@@ -283,11 +289,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
             type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
         )
         msg, msgctx = self.build_event()
-        self.get_success(
-            self._storage_controllers.persistence.persist_events(
-                [(j2, j2ctx), (msg, msgctx)]
-            )
-        )
+        self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
         self.replicate()
         assert j2.internal_metadata.stream_ordering is not None
 
@@ -339,7 +341,7 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
 
     event_id = 0
 
-    def persist(self, backfill=False, **kwargs) -> FrozenEvent:
+    def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
         """
         Returns:
             The event that was persisted.
@@ -348,32 +350,28 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
 
         if backfill:
             self.get_success(
-                self._storage_controllers.persistence.persist_events(
-                    [(event, context)], backfilled=True
-                )
+                self.persistance.persist_events([(event, context)], backfilled=True)
             )
         else:
-            self.get_success(
-                self._storage_controllers.persistence.persist_event(event, context)
-            )
+            self.get_success(self.persistance.persist_event(event, context))
 
         return event
 
     def build_event(
         self,
-        sender=USER_ID,
-        room_id=ROOM_ID,
-        type="m.room.message",
-        key=None,
+        sender: str = USER_ID,
+        room_id: str = ROOM_ID,
+        type: str = "m.room.message",
+        key: Optional[str] = None,
         internal: Optional[dict] = None,
-        depth=None,
-        prev_events: Optional[list] = None,
-        auth_events: Optional[list] = None,
-        prev_state: Optional[list] = None,
-        redacts=None,
+        depth: Optional[int] = None,
+        prev_events: Optional[List[Tuple[str, dict]]] = None,
+        auth_events: Optional[List[str]] = None,
+        prev_state: Optional[List[str]] = None,
+        redacts: Optional[str] = None,
         push_actions: Iterable = frozenset(),
-        **content,
-    ):
+        **content: object,
+    ) -> Tuple[EventBase, EventContext]:
         prev_events = prev_events or []
         auth_events = auth_events or []
         prev_state = prev_state or []
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 50fbff5f32..01df1be047 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
 
 
 class AccountDataStreamTestCase(BaseStreamTestCase):
-    def test_update_function_room_account_data_limit(self):
+    def test_update_function_room_account_data_limit(self) -> None:
         """Test replication with many room account data updates"""
         store = self.hs.get_datastores().main
 
@@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_global_account_data_limit(self):
+    def test_update_function_global_account_data_limit(self) -> None:
         """Test replication with many global account data updates"""
         store = self.hs.get_datastores().main
 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 641a94133b..65ef4bb160 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Any, List, Optional, Sequence
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
@@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
 )
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseStreamTestCase
 from tests.test_utils.event_injection import inject_event, inject_member_event
@@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
         self.user_id = self.register_user("u1", "pass")
         self.user_tok = self.login("u1", "pass")
@@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
         self.test_handler.received_rdata_rows.clear()
 
-    def test_update_function_event_row_limit(self):
+    def test_update_function_event_row_limit(self) -> None:
         """Test replication with many non-state events
 
         Checks that all events are correctly replicated when there are lots of
@@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_huge_state_change(self):
+    def test_update_function_huge_state_change(self) -> None:
         """Test replication with many state events
 
         Ensures that all events are correctly replicated when there are lots of
@@ -135,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: List[str] = self.get_success(
+        fork_point: Sequence[str] = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -164,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         pl_event = self.get_success(
             inject_event(
                 self.hs,
-                prev_event_ids=prev_events,
+                prev_event_ids=list(prev_events),
                 type=EventTypes.PowerLevels,
                 state_key="",
                 sender=self.user_id,
@@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             # "None" indicates the state has been deleted
             self.assertIsNone(sr.event_id)
 
-    def test_update_function_state_row_limit(self):
+    def test_update_function_state_row_limit(self) -> None:
         """Test replication with many state events over several stream ids."""
 
         # we want to generate lots of state changes, but for this test, we want to
@@ -290,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: List[str] = self.get_success(
+        fork_point: Sequence[str] = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -319,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             e = self.get_success(
                 inject_event(
                     self.hs,
-                    prev_event_ids=prev_events,
+                    prev_event_ids=list(prev_events),
                     type=EventTypes.PowerLevels,
                     state_key="",
                     sender=self.user_id,
@@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_backwards_stream_id(self):
+    def test_backwards_stream_id(self) -> None:
         """
         Test that RDATA that comes after the current position should be discarded.
         """
@@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
     event_count = 0
 
     def _inject_test_event(
-        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
+        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
     ) -> EventBase:
         if sender is None:
             sender = self.user_id
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index bcb82c9c80..cdbdfaf057 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
         config["federation_sender_instances"] = ["federation_sender1"]
         return config
 
-    def test_catchup(self):
+    def test_catchup(self) -> None:
         """Basic test of catchup on reconnect
 
         Makes sure that updates sent while we are offline are received later.
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 2c10eab4db..452ac85069 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
     hijack_auth = True
     user_id = "@bob:test"
 
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.store = self.hs.get_datastores().main
 
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
         room_id = self.helper.create_room_as("@bob:test")
         # Mark the room as partial-stated.
         self.get_success(
-            self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
+            self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
         )
 
         worker = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 9a229dd23f..5a38ac831f 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from unittest.mock import Mock
 
-from synapse.handlers.typing import RoomMember
+from synapse.handlers.typing import RoomMember, TypingWriterHandler
 from synapse.replication.tcp.streams import TypingStream
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -27,11 +27,13 @@ ROOM_ID_2 = "!foo:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
-    def _build_replication_data_handler(self):
-        return Mock(wraps=super()._build_replication_data_handler())
+    def _build_replication_data_handler(self) -> Mock:
+        self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
+        return self.mock_handler
 
-    def test_typing(self):
+    def test_typing(self) -> None:
         typing = self.hs.get_typing_handler()
+        assert isinstance(typing, TypingWriterHandler)
 
         self.reconnect()
 
@@ -43,8 +45,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -54,11 +56,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
         # Now let's disconnect and insert some data.
         self.disconnect()
 
-        self.test_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.reset_mock()
 
         typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
 
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.assert_not_called()
 
         self.reconnect()
         self.pump(0.1)
@@ -71,15 +73,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         assert request.args is not None
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([], row.user_ids)
 
-    def test_reset(self):
+    def test_reset(self) -> None:
         """
         Test what happens when a typing stream resets.
 
@@ -87,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         sends the proper position and RDATA).
         """
         typing = self.hs.get_typing_handler()
+        assert isinstance(typing, TypingWriterHandler)
 
         self.reconnect()
 
@@ -98,8 +101,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -134,15 +137,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         # Reset the test code.
-        self.test_handler.on_rdata.reset_mock()
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.assert_not_called()
 
         # Push additional data.
         typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
         self.reactor.advance(0)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index cca7ebb719..5d6b72b16d 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -21,12 +21,12 @@ from tests.unittest import TestCase
 
 
 class ParseCommandTestCase(TestCase):
-    def test_parse_one_word_command(self):
+    def test_parse_one_word_command(self) -> None:
         line = "REPLICATE"
         cmd = parse_command_from_line(line)
         self.assertIsInstance(cmd, ReplicateCommand)
 
-    def test_parse_rdata(self):
+    def test_parse_rdata(self) -> None:
         line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
         cmd = parse_command_from_line(line)
         assert isinstance(cmd, RdataCommand)
@@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
         self.assertEqual(cmd.instance_name, "master")
         self.assertEqual(cmd.token, 6287863)
 
-    def test_parse_rdata_batch(self):
+    def test_parse_rdata_batch(self) -> None:
         line = 'RDATA presence master batch ["@foo:example.com", "online"]'
         cmd = parse_command_from_line(line)
         assert isinstance(cmd, RdataCommand)
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 6e4055cc21..bab77b2df7 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
 
         # ... updating the cache ID gen on the master still shouldn't cause the
         # deferred to wake up.
+        assert store._cache_id_gen is not None
         ctx = store._cache_id_gen.get_next()
         self.get_success(ctx.__aenter__())
         self.get_success(ctx.__aexit__(None, None, None))
@@ -140,3 +141,64 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
         self.get_success(ctx_worker1.__aexit__(None, None, None))
 
         self.assertTrue(d.called)
+
+    def test_wait_for_stream_position_rdata(self) -> None:
+        """Check that wait for stream position correctly waits for an update
+        from the correct instance, when RDATA is sent.
+        """
+        store = self.hs.get_datastores().main
+        cmd_handler = self.hs.get_replication_command_handler()
+        data_handler = self.hs.get_replication_data_handler()
+
+        worker1 = self.make_worker_hs(
+            "synapse.app.generic_worker",
+            extra_config={
+                "worker_name": "worker1",
+                "run_background_tasks_on": "worker1",
+                "redis": {"enabled": True},
+            },
+        )
+
+        cache_id_gen = worker1.get_datastores().main._cache_id_gen
+        assert cache_id_gen is not None
+
+        self.replicate()
+
+        # First, make sure the master knows that `worker1` exists.
+        initial_token = cache_id_gen.get_current_token()
+        cmd_handler.send_command(
+            PositionCommand("caches", "worker1", initial_token, initial_token)
+        )
+        self.replicate()
+
+        # `wait_for_stream_position` should only return once master receives a
+        # notification that `next_token2` has persisted.
+        ctx_worker1 = cache_id_gen.get_next_mult(2)
+        next_token1, next_token2 = self.get_success(ctx_worker1.__aenter__())
+
+        d = defer.ensureDeferred(
+            data_handler.wait_for_stream_position("worker1", "caches", next_token2)
+        )
+        self.assertFalse(d.called)
+
+        # Insert an entry into the cache stream with token `next_token1`, but
+        # not `next_token2`.
+        self.get_success(
+            store.db_pool.simple_insert(
+                table="cache_invalidation_stream_by_instance",
+                values={
+                    "stream_id": next_token1,
+                    "instance_name": "worker1",
+                    "cache_func": "foo",
+                    "keys": [],
+                    "invalidation_ts": 0,
+                },
+            )
+        )
+
+        # Finish the context manager, triggering the data to be sent to master.
+        self.get_success(ctx_worker1.__aexit__(None, None, None))
+
+        # Master should get told about `next_token2`, so the deferred should
+        # resolve.
+        self.assertTrue(d.called)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index 545f11acd1..b75fc05fd5 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -16,15 +16,17 @@ from typing import Tuple
 
 from twisted.internet.address import IPv4Address
 from twisted.internet.interfaces import IProtocol
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
 
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
 
 class RemoteServerUpTestCase(HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.factory = ReplicationStreamProtocolFactory(hs)
 
     def _make_client(self) -> Tuple[IProtocol, StringTransport]:
@@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
 
         return proto, transport
 
-    def test_relay(self):
+    def test_relay(self) -> None:
         """Test that Synapse will relay REMOTE_SERVER_UP commands to all
         other connections, but not the one that sent it.
         """
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index 5d7a89e0c7..98602371e4 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -13,7 +13,11 @@
 # limitations under the License.
 import logging
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.rest.client import register
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import FakeChannel, make_request
@@ -27,7 +31,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
 
     servlets = [register.register_servlets]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         # This isn't a real configuration option but is used to provide the main
         # homeserver and worker homeserver different options.
@@ -77,7 +81,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
             {"auth": {"session": session, "type": "m.login.dummy"}},
         )
 
-    def test_no_auth(self):
+    def test_no_auth(self) -> None:
         """With no authentication the request should finish."""
         channel = self._test_register()
         self.assertEqual(channel.code, 200)
@@ -86,7 +90,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(channel.json_body["user_id"], "@user:test")
 
     @override_config({"main_replication_secret": "my-secret"})
-    def test_missing_auth(self):
+    def test_missing_auth(self) -> None:
         """If the main process expects a secret that is not provided, an error results."""
         channel = self._test_register()
         self.assertEqual(channel.code, 500)
@@ -97,13 +101,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
             "worker_replication_secret": "wrong-secret",
         }
     )
-    def test_unauthorized(self):
+    def test_unauthorized(self) -> None:
         """If the main process receives the wrong secret, an error results."""
         channel = self._test_register()
         self.assertEqual(channel.code, 500)
 
     @override_config({"worker_replication_secret": "my-secret"})
-    def test_authorized(self):
+    def test_authorized(self) -> None:
         """The request should finish when the worker provides the authentication header."""
         channel = self._test_register()
         self.assertEqual(channel.code, 200)
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index eb5b376534..eca5033761 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -33,7 +33,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         config["worker_replication_http_port"] = "8765"
         return config
 
-    def test_register_single_worker(self):
+    def test_register_single_worker(self) -> None:
         """Test that registration works when using a single generic worker."""
         worker_hs = self.make_worker_hs("synapse.app.generic_worker")
         site = self._hs_to_site[worker_hs]
@@ -63,7 +63,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         # We're given a registered user.
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
 
-    def test_register_multi_worker(self):
+    def test_register_multi_worker(self) -> None:
         """Test that registration works when using multiple generic workers."""
         worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
         worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 63b1dd40b5..12668b34c5 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -14,10 +14,14 @@
 
 from unittest import mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.replication.tcp.commands import FederationAckCommand
 from synapse.replication.tcp.protocol import IReplicationConnection
 from synapse.replication.tcp.streams.federation import FederationStream
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
@@ -30,12 +34,10 @@ class FederationAckTestCase(HomeserverTestCase):
         config["federation_sender_instances"] = ["federation_sender1"]
         return config
 
-    def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
-
-        return hs
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        return self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
 
-    def test_federation_ack_sent(self):
+    def test_federation_ack_sent(self) -> None:
         """A FEDERATION_ACK should be sent back after each RDATA federation
 
         This test checks that the federation sender is correctly sending back
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index c28073b8f7..08703206a9 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,6 +16,7 @@ from unittest.mock import Mock
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.builder import EventBuilderFactory
+from synapse.handlers.typing import TypingWriterHandler
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client import login, room
 from synapse.types import UserID, create_requester
@@ -40,7 +41,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         room.register_servlets,
     ]
 
-    def test_send_event_single_sender(self):
+    def test_send_event_single_sender(self) -> None:
         """Test that using a single federation sender worker correctly sends a
         new event.
         """
@@ -71,7 +72,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
         self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
 
-    def test_send_event_sharded(self):
+    def test_send_event_sharded(self) -> None:
         """Test that using two federation sender workers correctly sends
         new events.
         """
@@ -138,7 +139,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.assertTrue(sent_on_1)
         self.assertTrue(sent_on_2)
 
-    def test_send_typing_sharded(self):
+    def test_send_typing_sharded(self) -> None:
         """Test that using two federation sender workers correctly sends
         new typing EDUs.
         """
@@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         token = self.login("user3", "pass")
 
         typing_handler = self.hs.get_typing_handler()
+        assert isinstance(typing_handler, TypingWriterHandler)
 
         sent_on_1 = False
         sent_on_2 = False
@@ -215,7 +217,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.assertTrue(sent_on_1)
         self.assertTrue(sent_on_2)
 
-    def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+    def create_room_with_remote_server(
+        self, user: str, token: str, remote_server: str = "other_server"
+    ) -> str:
         room = self.helper.create_room_as(user, tok=token)
         store = self.hs.get_datastores().main
         federation = self.hs.get_federation_event_handler()
diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py
index b93cae67d3..9c4fbda71b 100644
--- a/tests/replication/test_module_cache_invalidation.py
+++ b/tests/replication/test_module_cache_invalidation.py
@@ -39,7 +39,7 @@ class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
         synapse.rest.admin.register_servlets,
     ]
 
-    def test_module_cache_full_invalidation(self):
+    def test_module_cache_full_invalidation(self) -> None:
         main_cache = TestCache()
         self.hs.get_module_api().register_cached_function(main_cache.cached_function)
 
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 96cdf2c45b..1527b4a82d 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -18,12 +18,14 @@ from typing import Optional, Tuple
 from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.http import HTTPChannel
 from twisted.web.server import Request
 
 from synapse.rest import admin
 from synapse.rest.client import login
 from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
 from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -43,13 +45,13 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("user", "pass")
         self.access_token = self.login("user", "pass")
 
         self.reactor.lookups["example.com"] = "1.2.3.4"
 
-    def default_config(self):
+    def default_config(self) -> dict:
         conf = super().default_config()
         conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
         return conf
@@ -122,7 +124,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
 
         return channel, request
 
-    def test_basic(self):
+    def test_basic(self) -> None:
         """Test basic fetching of remote media from a single worker."""
         hs1 = self.make_worker_hs("synapse.app.generic_worker")
 
@@ -138,7 +140,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.result["body"], b"Hello!")
 
-    def test_download_simple_file_race(self):
+    def test_download_simple_file_race(self) -> None:
         """Test that fetching remote media from two different processes at the
         same time works.
         """
@@ -177,7 +179,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         # We expect only one new file to have been persisted.
         self.assertEqual(start_count + 1, self._count_remote_media())
 
-    def test_download_image_race(self):
+    def test_download_image_race(self) -> None:
         """Test that fetching remote *images* from two different processes at
         the same time works.
 
@@ -229,7 +231,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         return sum(len(files) for _, _, files in os.walk(path))
 
 
-def get_connection_factory():
+def get_connection_factory() -> TestServerTLSConnectionFactory:
     # this needs to happen once, but not until we are ready to run the first test
     global test_server_connection_factory
     if test_server_connection_factory is None:
@@ -263,6 +265,6 @@ def _build_test_server(
     return server_tls_factory.buildProtocol(None)
 
 
-def _log_request(request):
+def _log_request(request: Request) -> None:
     """Implements Factory.log, which is expected by Request.finish"""
     logger.info("Completed request %s", request)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index ca18ad6553..0798b021c3 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -15,9 +15,12 @@ import logging
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 
@@ -33,12 +36,12 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         login.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Register a user who sends a message that we'll get notified about
         self.other_user_id = self.register_user("otheruser", "pass")
         self.other_access_token = self.login("otheruser", "pass")
 
-    def _create_pusher_and_send_msg(self, localpart):
+    def _create_pusher_and_send_msg(self, localpart: str) -> str:
         # Create a user that will get push notifications
         user_id = self.register_user(localpart, "pass")
         access_token = self.login(localpart, "pass")
@@ -47,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         user_dict = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
+        assert user_dict is not None
         token_id = user_dict.token_id
 
         self.get_success(
@@ -79,7 +83,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
 
         return event_id
 
-    def test_send_push_single_worker(self):
+    def test_send_push_single_worker(self) -> None:
         """Test that registration works when using a pusher worker."""
         http_client_mock = Mock(spec_set=["post_json_get_json"])
         http_client_mock.post_json_get_json.side_effect = (
@@ -109,7 +113,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
             ],
         )
 
-    def test_send_push_multiple_workers(self):
+    def test_send_push_multiple_workers(self) -> None:
         """Test that registration works when using sharded pusher workers."""
         http_client_mock1 = Mock(spec_set=["post_json_get_json"])
         http_client_mock1.post_json_get_json.side_effect = (
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 541d390286..7f9cc67e73 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -14,9 +14,13 @@
 import logging
 from unittest.mock import patch
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.rest import admin
 from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.util import Clock
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
@@ -34,7 +38,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         sync.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Register a user who sends a message that we'll get notified about
         self.other_user_id = self.register_user("otheruser", "pass")
         self.other_access_token = self.login("otheruser", "pass")
@@ -42,7 +46,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.room_creator = self.hs.get_room_creation_handler()
         self.store = hs.get_datastores().main
 
-    def default_config(self):
+    def default_config(self) -> dict:
         conf = super().default_config()
         conf["stream_writers"] = {"events": ["worker1", "worker2"]}
         conf["instance_map"] = {
@@ -51,7 +55,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         }
         return conf
 
-    def _create_room(self, room_id: str, user_id: str, tok: str):
+    def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
         """Create a room with given room_id"""
 
         # We control the room ID generation by patching out the
@@ -62,7 +66,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             mock.side_effect = lambda: room_id
             self.helper.create_room_as(user_id, tok=tok)
 
-    def test_basic(self):
+    def test_basic(self) -> None:
         """Simple test to ensure that multiple rooms can be created and joined,
         and that different rooms get handled by different instances.
         """
@@ -112,7 +116,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertTrue(persisted_on_1)
         self.assertTrue(persisted_on_2)
 
-    def test_vector_clock_token(self):
+    def test_vector_clock_token(self) -> None:
         """Tests that using a stream token with a vector clock component works
         correctly with basic /sync and /messages usage.
         """
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index 03f2112b07..aaa488bced 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -28,7 +28,6 @@ from tests import unittest
 
 
 class DeviceRestTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
@@ -291,7 +290,6 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
 
 
 class DevicesRestTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
@@ -415,7 +413,6 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
 
 
 class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 233eba3516..f189b07769 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -78,7 +78,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
         """
         Try to get an event report without authentication.
         """
-        channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, {})
 
         self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -473,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         """
         Try to get event report without authentication.
         """
-        channel = self.make_request("GET", self.url, b"{}")
+        channel = self.make_request("GET", self.url, {})
 
         self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@@ -599,3 +599,142 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
         self.assertIn("room_id", content["event_json"])
         self.assertIn("sender", content["event_json"])
         self.assertIn("content", content["event_json"])
+
+
+class DeleteEventReportTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self._store = hs.get_datastores().main
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        # create report
+        event_id = self.get_success(
+            self._store.add_event_report(
+                "room_id",
+                "event_id",
+                self.other_user,
+                "this makes me sad",
+                {},
+                self.clock.time_msec(),
+            )
+        )
+
+        self.url = f"/_synapse/admin/v1/event_reports/{event_id}"
+
+    def test_no_auth(self) -> None:
+        """
+        Try to delete event report without authentication.
+        """
+        channel = self.make_request("DELETE", self.url)
+
+        self.assertEqual(401, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_no_admin(self) -> None:
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            access_token=self.other_user_tok,
+        )
+
+        self.assertEqual(403, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_delete_success(self) -> None:
+        """
+        Testing delete a report.
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual({}, channel.json_body)
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        # check that report was deleted
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_invalid_report_id(self) -> None:
+        """
+        Testing that an invalid `report_id` returns a 400.
+        """
+
+        # `report_id` is negative
+        channel = self.make_request(
+            "DELETE",
+            "/_synapse/admin/v1/event_reports/-123",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+        self.assertEqual(
+            "The report_id parameter must be a string representing a positive integer.",
+            channel.json_body["error"],
+        )
+
+        # `report_id` is a non-numerical string
+        channel = self.make_request(
+            "DELETE",
+            "/_synapse/admin/v1/event_reports/abcdef",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+        self.assertEqual(
+            "The report_id parameter must be a string representing a positive integer.",
+            channel.json_body["error"],
+        )
+
+        # `report_id` is undefined
+        channel = self.make_request(
+            "DELETE",
+            "/_synapse/admin/v1/event_reports/",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+        self.assertEqual(
+            "The report_id parameter must be a string representing a positive integer.",
+            channel.json_body["error"],
+        )
+
+    def test_report_id_not_found(self) -> None:
+        """
+        Testing that a not existing `report_id` returns a 404.
+        """
+
+        channel = self.make_request(
+            "DELETE",
+            "/_synapse/admin/v1/event_reports/123",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+        self.assertEqual("Event report not found", channel.json_body["error"])
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index aadb31ca83..6d04911d67 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -20,8 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.errors import Codes
+from synapse.media.filepath import MediaFilePaths
 from synapse.rest.client import login, profile, room
-from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -34,7 +34,6 @@ INVALID_TIMESTAMP_IN_S = 1893456000  # 2030-01-01 in seconds
 
 
 class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.admin.register_servlets_for_media_repo,
@@ -196,7 +195,6 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
 
 
 class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.admin.register_servlets_for_media_repo,
@@ -213,7 +211,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.admin_user_tok = self.login("admin", "pass")
 
         self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
-        self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
+        self.url = "/_synapse/admin/v1/media/delete"
+        self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
 
         # Move clock up to somewhat realistic time
         self.reactor.advance(1000000000)
@@ -332,11 +331,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_delete_media_never_accessed(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None:
         """
         Tests that media deleted if it is older than `before_ts` and never accessed
         `last_access_ts` is `NULL` and `created_ts` < `before_ts`
         """
+        url = self.legacy_url if use_legacy_url else self.url
 
         # upload and do not access
         server_and_media_id = self._create_media()
@@ -351,7 +352,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         now_ms = self.clock.time_msec()
         channel = self.make_request(
             "POST",
-            self.url + "?before_ts=" + str(now_ms),
+            url + "?before_ts=" + str(now_ms),
             access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -591,7 +592,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
 
 
 class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.admin.register_servlets_for_media_repo,
@@ -721,7 +721,6 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
 
 
 class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.admin.register_servlets_for_media_repo,
@@ -818,7 +817,6 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
 
 
 class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.admin.register_servlets_for_media_repo,
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 453a6e979c..9dbb778679 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1990,7 +1990,6 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
 
 
 class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         room.register_servlets,
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
index a2f347f666..28b999573e 100644
--- a/tests/rest/admin/test_server_notice.py
+++ b/tests/rest/admin/test_server_notice.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from typing import List, Sequence
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -28,7 +28,6 @@ from tests.unittest import override_config
 
 
 class ServerNoticeTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
@@ -558,7 +557,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
 
     def _check_invite_and_join_status(
         self, user_id: str, expected_invites: int, expected_memberships: int
-    ) -> List[RoomsForUser]:
+    ) -> Sequence[RoomsForUser]:
         """Check invite and room membership status of a user.
 
         Args
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..4b8f889a71 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -28,8 +28,8 @@ import synapse.rest.admin
 from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
+from synapse.media.filepath import MediaFilePaths
 from synapse.rest.client import devices, login, logout, profile, register, room, sync
-from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
@@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         other_user_tok = self.login("user", "pass")
         event_builder_factory = self.hs.get_event_builder_factory()
         event_creation_handler = self.hs.get_event_creation_handler()
-        storage_controllers = self.hs.get_storage_controllers()
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
 
         # Create two rooms, one with a local user only and one with both a local
         # and remote user.
@@ -2934,11 +2935,13 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(storage_controllers.persistence.persist_event(event, context))
+        context = self.get_success(unpersisted_context.persist(event))
+
+        self.get_success(persistence.persist_event(event, context))
 
         # Now get rooms
         url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py
index 30f12f1bff..6c04e6c56c 100644
--- a/tests/rest/admin/test_username_available.py
+++ b/tests/rest/admin/test_username_available.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
+
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
@@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
         self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
 
-        async def check_username(username: str) -> bool:
-            if username == "allowed":
-                return True
+        async def check_username(
+            localpart: str,
+            guest_access_token: Optional[str] = None,
+            assigned_user_id: Optional[str] = None,
+            inhibit_user_in_use_error: bool = False,
+        ) -> None:
+            if localpart == "allowed":
+                return
             raise SynapseError(
                 400,
                 "User ID already taken.",
@@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
             )
 
         handler = self.hs.get_registration_handler()
-        handler.check_username = check_username
+        handler.check_username = check_username  # type: ignore[assignment]
 
     def test_username_available(self) -> None:
         """
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 88f255c9ee..2b05dffc7d 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
 
 
 class PasswordResetTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         account.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -408,7 +407,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
 
 class DeactivateTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -492,7 +490,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
 
 
 class WhoamiTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -567,7 +564,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
 
 
 class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         account.register_servlets,
         login.register_servlets,
@@ -1193,7 +1189,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
                 return {}
 
         # Register a mock that will return the expected result depending on the remote.
-        self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
+        self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)  # type: ignore[assignment]
 
         # Check that we've got the correct response from the client-side endpoint.
         self._test_status(
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 208ec44829..0d8fe77b88 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -34,7 +34,7 @@ from synapse.util import Clock
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
-from tests.server import FakeChannel, make_request
+from tests.server import FakeChannel
 from tests.unittest import override_config, skip_unless
 
 
@@ -43,13 +43,15 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
         super().__init__(hs)
         self.recaptcha_attempts: List[Tuple[dict, str]] = []
 
+    def is_enabled(self) -> bool:
+        return True
+
     def check_auth(self, authdict: dict, clientip: str) -> Any:
         self.recaptcha_attempts.append((authdict, clientip))
         return succeed(True)
 
 
 class FallbackAuthTests(unittest.HomeserverTestCase):
-
     servlets = [
         auth.register_servlets,
         register.register_servlets,
@@ -57,7 +59,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
     hijack_auth = False
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
 
         config["enable_registration_captcha"] = True
@@ -1319,16 +1320,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
         channel = self.submit_logout_token(logout_token)
         self.assertEqual(channel.code, 200)
 
-        # Now try to exchange the login token
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "POST",
-            "/login",
-            content={"type": "m.login.token", "token": login_token},
-        )
-        # It should have failed
-        self.assertEqual(channel.code, 403)
+        # Now try to exchange the login token, it should fail.
+        self.helper.login_via_token(login_token, 403)
 
     @override_config(
         {
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index d1751e1557..c16e8d43f4 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -26,7 +26,6 @@ from tests.unittest import override_config
 
 
 class CapabilitiesTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         capabilities.register_servlets,
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index b1ca81a911..bb845179d3 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -38,7 +38,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
     hijack_auth = False
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
         config["form_secret"] = "123abc"
 
diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py
index 7a88aa2cda..6490e883bf 100644
--- a/tests/rest/client/test_directory.py
+++ b/tests/rest/client/test_directory.py
@@ -28,7 +28,6 @@ from tests.unittest import override_config
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         directory.register_servlets,
diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py
index 9fa1f82dfe..f31ebc8021 100644
--- a/tests/rest/client/test_ephemeral_message.py
+++ b/tests/rest/client/test_ephemeral_message.py
@@ -26,7 +26,6 @@ from tests import unittest
 
 
 class EphemeralMessageTestCase(unittest.HomeserverTestCase):
-
     user_id = "@user:test"
 
     servlets = [
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
index a9b7db9db2..54df2a252c 100644
--- a/tests/rest/client/test_events.py
+++ b/tests/rest/client/test_events.py
@@ -38,7 +38,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
         config["enable_registration_captcha"] = False
         config["enable_registration"] = True
@@ -51,7 +50,6 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         # register an account
         self.user_id = self.register_user("sid1", "pass")
         self.token = self.login(self.user_id, "pass")
@@ -142,7 +140,6 @@ class GetEventsTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         # register an account
         self.user_id = self.register_user("sid1", "pass")
         self.token = self.login(self.user_id, "pass")
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index afc8d641be..91678abf13 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -25,7 +25,6 @@ PATH_PREFIX = "/_matrix/client/v2_alpha"
 
 
 class FilterTestCase(unittest.HomeserverTestCase):
-
     user_id = "@apple:test"
     hijack_auth = True
     EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
@@ -63,14 +62,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
     def test_add_filter_non_local_user(self) -> None:
         _is_mine = self.hs.is_mine
-        self.hs.is_mine = lambda target_user: False
+        self.hs.is_mine = lambda target_user: False  # type: ignore[assignment]
         channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % (self.user_id),
             self.EXAMPLE_FILTER_JSON,
         )
 
-        self.hs.is_mine = _is_mine
+        self.hs.is_mine = _is_mine  # type: ignore[assignment]
         self.assertEqual(channel.code, 403)
         self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
 
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index 741fecea77..8ee5489057 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -14,12 +14,21 @@
 
 from http import HTTPStatus
 
+from signedjson.key import (
+    encode_verify_key_base64,
+    generate_signing_key,
+    get_verify_key,
+)
+from signedjson.sign import sign_json
+
 from synapse.api.errors import Codes
 from synapse.rest import admin
 from synapse.rest.client import keys, login
+from synapse.types import JsonDict
 
 from tests import unittest
 from tests.http.server._base import make_request_with_cancellation_test
+from tests.unittest import override_config
 
 
 class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -118,3 +127,135 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assertIn(bob, channel.json_body["device_keys"])
+
+    def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
+        # We only generate a master key to simplify the test.
+        master_signing_key = generate_signing_key(device_id)
+        master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key))
+
+        return {
+            "master_key": sign_json(
+                {
+                    "user_id": user_id,
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + master_verify_key: master_verify_key},
+                },
+                user_id,
+                master_signing_key,
+            ),
+        }
+
+    def test_device_signing_with_uia(self) -> None:
+        """Device signing key upload requires UIA."""
+        password = "wonderland"
+        device_id = "ABCDEFGHI"
+        alice_id = self.register_user("alice", password)
+        alice_token = self.login("alice", password, device_id=device_id)
+
+        content = self.make_device_keys(alice_id, device_id)
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            content,
+            alice_token,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        # Grab the session
+        session = channel.json_body["session"]
+        # Ensure that flows are what is expected.
+        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+        # add UI auth
+        content["auth"] = {
+            "type": "m.login.password",
+            "identifier": {"type": "m.id.user", "user": alice_id},
+            "password": password,
+            "session": session,
+        }
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            content,
+            alice_token,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+    @override_config({"ui_auth": {"session_timeout": "15m"}})
+    def test_device_signing_with_uia_session_timeout(self) -> None:
+        """Device signing key upload requires UIA buy passes with grace period."""
+        password = "wonderland"
+        device_id = "ABCDEFGHI"
+        alice_id = self.register_user("alice", password)
+        alice_token = self.login("alice", password, device_id=device_id)
+
+        content = self.make_device_keys(alice_id, device_id)
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            content,
+            alice_token,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+    @override_config(
+        {
+            "experimental_features": {"msc3967_enabled": True},
+            "ui_auth": {"session_timeout": "15s"},
+        }
+    )
+    def test_device_signing_with_msc3967(self) -> None:
+        """Device signing key follows MSC3967 behaviour when enabled."""
+        password = "wonderland"
+        device_id = "ABCDEFGHI"
+        alice_id = self.register_user("alice", password)
+        alice_token = self.login("alice", password, device_id=device_id)
+
+        keys1 = self.make_device_keys(alice_id, device_id)
+
+        # Initial request should succeed as no existing keys are present.
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            keys1,
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+        keys2 = self.make_device_keys(alice_id, device_id)
+
+        # Subsequent request should require UIA as keys already exist even though session_timeout is set.
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            keys2,
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+
+        # Grab the session
+        session = channel.json_body["session"]
+        # Ensure that flows are what is expected.
+        self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+
+        # add UI auth
+        keys2["auth"] = {
+            "type": "m.login.password",
+            "identifier": {"type": "m.id.user", "user": alice_id},
+            "password": password,
+            "session": session,
+        }
+
+        # Request should complete
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/keys/device_signing/upload",
+            keys2,
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index ff5baa9f0a..62acf4f44e 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -89,7 +89,6 @@ ADDITIONAL_LOGIN_FLOWS = [
 
 
 class LoginRestServletTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -737,7 +736,6 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
 
 class CASTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         login.register_servlets,
     ]
diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py
index 6aedc1a11c..b8187db982 100644
--- a/tests/rest/client/test_login_token_request.py
+++ b/tests/rest/client/test_login_token_request.py
@@ -26,7 +26,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token"
 
 
 class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         login.register_servlets,
         admin.register_servlets,
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index b3738a0304..dcbb125a3b 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -35,15 +35,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
     servlets = [presence.register_servlets]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
-        presence_handler = Mock(spec=PresenceHandler)
-        presence_handler.set_state.return_value = make_awaitable(None)
+        self.presence_handler = Mock(spec=PresenceHandler)
+        self.presence_handler.set_state.return_value = make_awaitable(None)
 
         hs = self.setup_test_homeserver(
             "red",
             federation_http_client=None,
             federation_client=Mock(),
-            presence_handler=presence_handler,
+            presence_handler=self.presence_handler,
         )
 
         return hs
@@ -61,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(channel.code, HTTPStatus.OK)
-        self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
+        self.assertEqual(self.presence_handler.set_state.call_count, 1)
 
     @unittest.override_config({"use_presence": False})
     def test_put_presence_disabled(self) -> None:
@@ -76,4 +75,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(channel.code, HTTPStatus.OK)
-        self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
+        self.assertEqual(self.presence_handler.set_state.call_count, 0)
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 8de5a342ae..27c93ad761 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -30,7 +30,6 @@ from tests import unittest
 
 
 class ProfileTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -324,7 +323,6 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
 
 class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
@@ -404,7 +402,6 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
 
 
 class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 11cf3939d8..b228dba861 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -40,7 +40,6 @@ from tests.unittest import override_config
 
 
 class RegisterRestServletTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         login.register_servlets,
         register.register_servlets,
@@ -151,7 +150,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
 
     def test_POST_guest_registration(self) -> None:
-        self.hs.config.key.macaroon_secret_key = "test"
+        self.hs.config.key.macaroon_secret_key = b"test"
         self.hs.config.registration.allow_guest_access = True
 
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
@@ -797,7 +796,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
 
 class AccountValidityTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         register.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -913,7 +911,6 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
 
 
 class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         register.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -1132,7 +1129,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
 
 
 class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
-
     servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@@ -1166,12 +1162,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
         """
         user_id = self.register_user("kermit_delta", "user")
 
-        self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+        self.hs.config.account_validity.account_validity_startup_job_max_delta = (
+            self.max_delta
+        )
 
         now_ms = self.hs.get_clock().time_msec()
         self.get_success(self.store._set_expiration_date_when_missing())
 
         res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+        assert res is not None
 
         self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
         self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c8a6911d5e..fbbbcb23f1 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -30,7 +30,6 @@ from tests import unittest
 from tests.server import FakeChannel
 from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import inject_event
-from tests.unittest import override_config
 
 
 class BaseRelationsTestCase(unittest.HomeserverTestCase):
@@ -403,7 +402,7 @@ class RelationsTestCase(BaseRelationsTestCase):
 
     def test_edit(self) -> None:
         """Test that a simple edit works."""
-
+        orig_body = {"body": "Hi!", "msgtype": "m.text"}
         new_body = {"msgtype": "m.text", "body": "I've been edited!"}
         edit_event_content = {
             "msgtype": "m.text",
@@ -424,9 +423,7 @@ class RelationsTestCase(BaseRelationsTestCase):
             access_token=self.user_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(
-            channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
-        )
+        self.assertEqual(channel.json_body["content"], orig_body)
         self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
 
         # Request the room messages.
@@ -443,7 +440,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
 
         # Request the room context.
-        # /context should return the edited event.
+        # /context should return the event.
         channel = self.make_request(
             "GET",
             f"/rooms/{self.room}/context/{self.parent_id}",
@@ -453,7 +450,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         self._assert_edit_bundle(
             channel.json_body["event"], edit_event_id, edit_event_content
         )
-        self.assertEqual(channel.json_body["event"]["content"], new_body)
+        self.assertEqual(channel.json_body["event"]["content"], orig_body)
 
         # Request sync, but limit the timeline so it becomes limited (and includes
         # bundled aggregations).
@@ -491,45 +488,11 @@ class RelationsTestCase(BaseRelationsTestCase):
             edit_event_content,
         )
 
-    @override_config({"experimental_features": {"msc3925_inhibit_edit": True}})
-    def test_edit_inhibit_replace(self) -> None:
-        """
-        If msc3925_inhibit_edit is enabled, then the original event should not be
-        replaced.
-        """
-
-        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
-        edit_event_content = {
-            "msgtype": "m.text",
-            "body": "foo",
-            "m.new_content": new_body,
-        }
-        channel = self._send_relation(
-            RelationTypes.REPLACE,
-            "m.room.message",
-            content=edit_event_content,
-        )
-        edit_event_id = channel.json_body["event_id"]
-
-        # /context should return the *original* event.
-        channel = self.make_request(
-            "GET",
-            f"/rooms/{self.room}/context/{self.parent_id}",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(
-            channel.json_body["event"]["content"], {"body": "Hi!", "msgtype": "m.text"}
-        )
-        self._assert_edit_bundle(
-            channel.json_body["event"], edit_event_id, edit_event_content
-        )
-
     def test_multi_edit(self) -> None:
         """Test that multiple edits, including attempts by people who
         shouldn't be allowed, are correctly handled.
         """
-
+        orig_body = orig_body = {"body": "Hi!", "msgtype": "m.text"}
         self._send_relation(
             RelationTypes.REPLACE,
             "m.room.message",
@@ -570,7 +533,7 @@ class RelationsTestCase(BaseRelationsTestCase):
         )
         self.assertEqual(200, channel.code, channel.json_body)
 
-        self.assertEqual(channel.json_body["event"]["content"], new_body)
+        self.assertEqual(channel.json_body["event"]["content"], orig_body)
         self._assert_edit_bundle(
             channel.json_body["event"], edit_event_id, edit_event_content
         )
@@ -642,6 +605,7 @@ class RelationsTestCase(BaseRelationsTestCase):
 
     def test_edit_edit(self) -> None:
         """Test that an edit cannot be edited."""
+        orig_body = {"body": "Hi!", "msgtype": "m.text"}
         new_body = {"msgtype": "m.text", "body": "Initial edit"}
         edit_event_content = {
             "msgtype": "m.text",
@@ -675,14 +639,12 @@ class RelationsTestCase(BaseRelationsTestCase):
             access_token=self.user_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(
-            channel.json_body["content"], {"body": "Hi!", "msgtype": "m.text"}
-        )
+        self.assertEqual(channel.json_body["content"], orig_body)
 
         # The relations information should not include the edit to the edit.
         self._assert_edit_bundle(channel.json_body, edit_event_id, edit_event_content)
 
-        # /context should return the event updated for the *first* edit
+        # /context should return the bundled edit for the *first* edit
         # (The edit to the edit should be ignored.)
         channel = self.make_request(
             "GET",
@@ -690,7 +652,7 @@ class RelationsTestCase(BaseRelationsTestCase):
             access_token=self.user_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
-        self.assertEqual(channel.json_body["event"]["content"], new_body)
+        self.assertEqual(channel.json_body["event"]["content"], orig_body)
         self._assert_edit_bundle(
             channel.json_body["event"], edit_event_id, edit_event_content
         )
@@ -1080,48 +1042,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         ]
         assert_bundle(self._find_event_in_chunk(chunk))
 
-    def test_annotation(self) -> None:
-        """
-        Test that annotations get correctly bundled.
-        """
-        # Setup by sending a variety of relations.
-        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
-        )
-        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
-        def assert_annotations(bundled_aggregations: JsonDict) -> None:
-            self.assertEqual(
-                {
-                    "chunk": [
-                        {"type": "m.reaction", "key": "a", "count": 2},
-                        {"type": "m.reaction", "key": "b", "count": 1},
-                    ]
-                },
-                bundled_aggregations,
-            )
-
-        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
-
-    def test_annotation_to_annotation(self) -> None:
-        """Any relation to an annotation should be ignored."""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        event_id = channel.json_body["event_id"]
-        self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id
-        )
-
-        # Fetch the initial annotation event to see if it has bundled aggregations.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}",
-            access_token=self.user_token,
-        )
-        self.assertEquals(200, channel.code, channel.json_body)
-        # The first annotationt should not have any bundled aggregations.
-        self.assertNotIn("m.relations", channel.json_body["unsigned"])
-
     def test_reference(self) -> None:
         """
         Test that references get correctly bundled.
@@ -1138,7 +1058,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations,
             )
 
-        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
 
     def test_thread(self) -> None:
         """
@@ -1183,7 +1103,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 
         # The "user" sent the root event and is making queries for the bundled
         # aggregations: they have participated.
-        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7)
+        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 6)
         # The "user2" sent replies in the thread and is making queries for the
         # bundled aggregations: they have participated.
         #
@@ -1208,9 +1128,10 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
         thread_2 = channel.json_body["event_id"]
 
-        self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
+        channel = self._send_relation(
+            RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_2
         )
+        reference_event_id = channel.json_body["event_id"]
 
         def assert_thread(bundled_aggregations: JsonDict) -> None:
             self.assertEqual(2, bundled_aggregations.get("count"))
@@ -1235,17 +1156,15 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
             self.assert_dict(
                 {
                     "m.relations": {
-                        RelationTypes.ANNOTATION: {
-                            "chunk": [
-                                {"type": "m.reaction", "key": "a", "count": 1},
-                            ]
+                        RelationTypes.REFERENCE: {
+                            "chunk": [{"event_id": reference_event_id}]
                         },
                     }
                 },
                 bundled_aggregations["latest_event"].get("unsigned"),
             )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 6)
 
     def test_nested_thread(self) -> None:
         """
@@ -1330,7 +1249,6 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         thread_summary = relations_dict[RelationTypes.THREAD]
         self.assertIn("latest_event", thread_summary)
         latest_event_in_thread = thread_summary["latest_event"]
-        self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
         # The latest event in the thread should have the edit appear under the
         # bundled aggregations.
         self.assertDictContainsSubset(
@@ -1363,10 +1281,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
         thread_id = channel.json_body["event_id"]
 
-        # Annotate the thread.
-        self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+        # Make a reference to the thread.
+        channel = self._send_relation(
+            RelationTypes.REFERENCE, "org.matrix.test", parent_id=thread_id
         )
+        reference_event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
             "GET",
@@ -1377,9 +1296,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         self.assertEqual(
             channel.json_body["unsigned"].get("m.relations"),
             {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
+                RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
             },
         )
 
@@ -1396,9 +1313,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         self.assertEqual(
             thread_message["unsigned"].get("m.relations"),
             {
-                RelationTypes.ANNOTATION: {
-                    "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
-                },
+                RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
             },
         )
 
@@ -1410,7 +1325,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         Note that the spec allows for a server to return additional fields beyond
         what is specified.
         """
-        self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
+        reference_event_id = channel.json_body["event_id"]
 
         # Note that the sync filter does not include "unsigned" as a field.
         filter = urllib.parse.quote_plus(
@@ -1428,7 +1344,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 
         # Ensure there's bundled aggregations on it.
         self.assertIn("unsigned", parent_event)
-        self.assertIn("m.relations", parent_event["unsigned"])
+        self.assertEqual(
+            parent_event["unsigned"].get("m.relations"),
+            {
+                RelationTypes.REFERENCE: {"chunk": [{"event_id": reference_event_id}]},
+            },
+        )
 
 
 class RelationIgnoredUserTestCase(BaseRelationsTestCase):
@@ -1475,53 +1396,8 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
 
         return before_aggregations[relation_type], after_aggregations[relation_type]
 
-    def test_annotation(self) -> None:
-        """Annotations should ignore"""
-        # Send 2 from us, 2 from the to be ignored user.
-        allowed_event_ids = []
-        ignored_event_ids = []
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
-        allowed_event_ids.append(channel.json_body["event_id"])
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="b")
-        allowed_event_ids.append(channel.json_body["event_id"])
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION,
-            "m.reaction",
-            key="a",
-            access_token=self.user2_token,
-        )
-        ignored_event_ids.append(channel.json_body["event_id"])
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION,
-            "m.reaction",
-            key="c",
-            access_token=self.user2_token,
-        )
-        ignored_event_ids.append(channel.json_body["event_id"])
-
-        before_aggregations, after_aggregations = self._test_ignored_user(
-            RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids
-        )
-
-        self.assertCountEqual(
-            before_aggregations["chunk"],
-            [
-                {"type": "m.reaction", "key": "a", "count": 2},
-                {"type": "m.reaction", "key": "b", "count": 1},
-                {"type": "m.reaction", "key": "c", "count": 1},
-            ],
-        )
-
-        self.assertCountEqual(
-            after_aggregations["chunk"],
-            [
-                {"type": "m.reaction", "key": "a", "count": 1},
-                {"type": "m.reaction", "key": "b", "count": 1},
-            ],
-        )
-
     def test_reference(self) -> None:
-        """Annotations should ignore"""
+        """Aggregations should exclude reference relations from ignored users"""
         channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
         allowed_event_ids = [channel.json_body["event_id"]]
 
@@ -1544,7 +1420,7 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
         )
 
     def test_thread(self) -> None:
-        """Annotations should ignore"""
+        """Aggregations should exclude thread releations from ignored users"""
         channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
         allowed_event_ids = [channel.json_body["event_id"]]
 
@@ -1618,43 +1494,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             for t in threads
         ]
 
-    def test_redact_relation_annotation(self) -> None:
-        """
-        Test that annotations of an event are properly handled after the
-        annotation is redacted.
-
-        The redacted relation should not be included in bundled aggregations or
-        the response to relations.
-        """
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-        to_redact_event_id = channel.json_body["event_id"]
-
-        channel = self._send_relation(
-            RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
-        )
-        unredacted_event_id = channel.json_body["event_id"]
-
-        # Both relations should exist.
-        event_ids = self._get_related_events()
-        relations = self._get_bundled_aggregations()
-        self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
-        self.assertEquals(
-            relations["m.annotation"],
-            {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
-        )
-
-        # Redact one of the reactions.
-        self._redact(to_redact_event_id)
-
-        # The unredacted relation should still exist.
-        event_ids = self._get_related_events()
-        relations = self._get_bundled_aggregations()
-        self.assertEquals(event_ids, [unredacted_event_id])
-        self.assertEquals(
-            relations["m.annotation"],
-            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
-        )
-
     def test_redact_relation_thread(self) -> None:
         """
         Test that thread replies are properly handled after the thread reply redacted.
@@ -1775,14 +1614,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         is redacted.
         """
         # Add a relation
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+        channel = self._send_relation(RelationTypes.REFERENCE, "org.matrix.test")
         related_event_id = channel.json_body["event_id"]
 
         # The relations should exist.
         event_ids = self._get_related_events()
         relations = self._get_bundled_aggregations()
         self.assertEqual(len(event_ids), 1)
-        self.assertIn(RelationTypes.ANNOTATION, relations)
+        self.assertIn(RelationTypes.REFERENCE, relations)
 
         # Redact the original event.
         self._redact(self.parent_id)
@@ -1792,8 +1631,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         relations = self._get_bundled_aggregations()
         self.assertEquals(event_ids, [related_event_id])
         self.assertEquals(
-            relations["m.annotation"],
-            {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
+            relations[RelationTypes.REFERENCE],
+            {"chunk": [{"event_id": related_event_id}]},
         )
 
     def test_redact_parent_thread(self) -> None:
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index c0eb5d01a6..8dbd64be55 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -25,7 +25,6 @@ endpoint = "/_matrix/client/unstable/org.matrix.msc3886/rendezvous"
 
 
 class RendezvousServletTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         rendezvous.register_servlets,
     ]
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index 7cb1017a4a..1250685d39 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -73,6 +73,18 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
         data = {"reason": None, "score": None}
         self._assert_status(400, data)
 
+    def test_cannot_report_nonexistent_event(self) -> None:
+        """
+        Tests that we don't accept event reports for events which do not exist.
+        """
+        channel = self.make_request(
+            "POST",
+            f"rooms/{self.room_id}/report/$nonsenseeventid:test",
+            {"reason": "i am very sad"},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(404, channel.code, msg=channel.result["body"])
+
     def _assert_status(self, response_status: int, data: JsonDict) -> None:
         channel = self.make_request(
             "POST", self.report_path, data, access_token=self.other_user_tok
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 9c8c1889d3..d3e06bf6b3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # Send a first event, which should be filtered out at the end of the test.
         resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
         first_event_id = resp.get("event_id")
+        assert isinstance(first_event_id, str)
 
         # Advance the time by 2 days. We're using the default retention policy, therefore
         # after this the first event will still be valid.
@@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # Send another event, which shouldn't get filtered out.
         resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
         valid_event_id = resp.get("event_id")
+        assert isinstance(valid_event_id, str)
 
         # Advance the time by another 2 days. After this, the first event should be
         # outdated but not the second one.
@@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         # Check that we can still access state events that were sent before the event that
         # has been purged.
-        self.get_event(room_id, create_event.event_id)
+        self.get_event(room_id, bool(create_event))
 
     def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
         event = self.get_success(self.store.get_event(event_id, allow_none=True))
@@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
             self.assertIsNone(event)
             return {}
 
-        self.assertIsNotNone(event)
+        assert event is not None
 
         time_now = self.clock.time_msec()
         serialized = self.serializer.serialize_event(event, time_now)
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 9222cab198..a4900703c4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -65,7 +65,6 @@ class RoomBase(unittest.HomeserverTestCase):
     servlets = [room.register_servlets, room.register_deprecated_servlets]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         self.hs = self.setup_test_homeserver(
             "red",
             federation_http_client=None,
@@ -92,7 +91,6 @@ class RoomPermissionsTestCase(RoomBase):
     rmcreator_id = "@notme:red"
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         self.helper.auth_user_id = self.rmcreator_id
         # create some rooms under the name rmcreator_id
         self.uncreated_rmid = "!aa:test"
@@ -715,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(33, channel.resource_usage.db_txn_count)
+        self.assertEqual(30, channel.resource_usage.db_txn_count)
 
     def test_post_room_initial_state(self) -> None:
         # POST with initial_state config key, expect new room id
@@ -728,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(36, channel.resource_usage.db_txn_count)
+        self.assertEqual(32, channel.resource_usage.db_txn_count)
 
     def test_post_room_visibility_key(self) -> None:
         # POST with visibility config key, expect new room id
@@ -1127,7 +1125,6 @@ class RoomInviteRatelimitTestCase(RoomBase):
 
 
 class RoomJoinTestCase(RoomBase):
-
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -2102,7 +2099,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
     hijack_auth = False
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         # Register the user who does the searching
         self.user_id2 = self.register_user("user", "pass")
         self.access_token = self.login("user", "pass")
@@ -2195,7 +2191,6 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
 
 
 class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -2203,7 +2198,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         self.url = b"/_matrix/client/r0/publicRooms"
 
         config = self.default_config()
@@ -2225,7 +2219,6 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
 
 
 class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -2233,7 +2226,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
         config["allow_public_rooms_without_auth"] = True
         self.hs = self.setup_test_homeserver(config=config)
@@ -2414,7 +2406,6 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
 
 
 class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -2983,7 +2974,6 @@ class RelationsTestCase(PaginationTestCase):
 
 
 class ContextTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -3359,7 +3349,6 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
 
 
 class ThreepidInviteTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -3382,8 +3371,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
         make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
-        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
-        self.hs.get_identity_handler().lookup_3pid = Mock(
+        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[assignment]
+        self.hs.get_identity_handler().lookup_3pid = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None),
         )
 
@@ -3438,13 +3427,14 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         """
         Test allowing/blocking threepid invites with a spam-check module.
 
-        In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
+        In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
+        """
         # Mock a few functions to prevent the test from failing due to failing to talk to
         # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
         make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
-        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
-        self.hs.get_identity_handler().lookup_3pid = Mock(
+        self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[assignment]
+        self.hs.get_identity_handler().lookup_3pid = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None),
         )
 
@@ -3563,8 +3553,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
         )
 
         event.internal_metadata.outlier = True
+        persistence = self._storage_controllers.persistence
+        assert persistence is not None
         self.get_success(
-            self._storage_controllers.persistence.persist_event(
+            persistence.persist_event(
                 event, EventContext.for_outlier(self._storage_controllers)
             )
         )
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index c807a37bc2..8d2cdf8751 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase):
     def test_invite_3pid(self) -> None:
         """Ensure that a 3PID invite does not attempt to contact the identity server."""
         identity_handler = self.hs.get_identity_handler()
-        identity_handler.lookup_3pid = Mock(
+        identity_handler.lookup_3pid = Mock(  # type: ignore[assignment]
             side_effect=AssertionError("This should not get called")
         )
 
@@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase):
             event_source.get_new_events(
                 user=UserID.from_string(self.other_user_id),
                 from_key=0,
-                limit=None,
+                limit=10,
                 room_ids=[room_id],
                 is_guest=False,
             )
@@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase):
                 self.banned_user_id,
             )
         )
+        assert event is not None
         self.assertEqual(
             event.content, {"membership": "join", "displayname": original_display_name}
         )
@@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase):
                 self.banned_user_id,
             )
         )
+        assert event is not None
         self.assertEqual(
             event.content, {"membership": "join", "displayname": original_display_name}
         )
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index b9047194dd..9c876c7a32 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -41,7 +41,6 @@ from tests.server import TimedOutException
 
 
 class FilterTestCase(unittest.HomeserverTestCase):
-
     user_id = "@apple:test"
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -191,7 +190,6 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
 
 
 class SyncTypingTests(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -892,7 +890,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
 
 
 class ExcludeRoomTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 3325d43a2f..3b99513707 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -137,6 +137,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """Tests that a forbidden event is forbidden from being sent, but an allowed one
         can be sent.
         """
+
         # patch the rules module with a Mock which will return False for some event
         # types
         async def check(
@@ -243,6 +244,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
     def test_modify_event(self) -> None:
         """The module can return a modified version of the event"""
+
         # first patch the event checker so that it will modify the event
         async def check(
             ev: EventBase, state: StateMap[EventBase]
@@ -275,6 +277,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
     def test_message_edit(self) -> None:
         """Ensure that the module doesn't cause issues with edited messages."""
+
         # first patch the event checker so that it will modify the event
         async def check(
             ev: EventBase, state: StateMap[EventBase]
@@ -425,7 +428,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         async def test_fn(
             event: EventBase, state_events: StateMap[EventBase]
         ) -> Tuple[bool, Optional[JsonDict]]:
-            if event.is_state and event.type == EventTypes.PowerLevels:
+            if event.is_state() and event.type == EventTypes.PowerLevels:
                 await api.create_and_send_event_into_room(
                     {
                         "room_id": event.room_id,
@@ -931,3 +934,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
         # Check that the mock was called with the right parameters
         self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+    def test_on_add_and_remove_user_third_party_identifier(self) -> None:
+        """Tests that the on_add_user_third_party_identifier and
+        on_remove_user_third_party_identifier module callbacks are called
+        just before associating and removing a 3PID to/from an account.
+        """
+        # Pretend to be a Synapse module and register both callbacks as mocks.
+        third_party_rules = self.hs.get_third_party_event_rules()
+        on_add_user_third_party_identifier_callback_mock = Mock(
+            return_value=make_awaitable(None)
+        )
+        on_remove_user_third_party_identifier_callback_mock = Mock(
+            return_value=make_awaitable(None)
+        )
+        third_party_rules._on_threepid_bind_callbacks.append(
+            on_add_user_third_party_identifier_callback_mock
+        )
+        third_party_rules._on_threepid_bind_callbacks.append(
+            on_remove_user_third_party_identifier_callback_mock
+        )
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Also register a normal user we can modify.
+        user_id = self.register_user("user", "password")
+
+        # Add a 3PID to the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {
+                "threepids": [
+                    {
+                        "medium": "email",
+                        "address": "foo@example.com",
+                    },
+                ],
+            },
+            access_token=admin_tok,
+        )
+
+        # Check that the mocked add callback was called with the appropriate
+        # 3PID details.
+        self.assertEqual(channel.code, 200, channel.json_body)
+        on_add_user_third_party_identifier_callback_mock.assert_called_once()
+        args = on_add_user_third_party_identifier_callback_mock.call_args[0]
+        self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+        # Now remove the 3PID from the user
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {
+                "threepids": [],
+            },
+            access_token=admin_tok,
+        )
+
+        # Check that the mocked remove callback was called with the appropriate
+        # 3PID details.
+        self.assertEqual(channel.code, 200, channel.json_body)
+        on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+        args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+        self.assertEqual(args, (user_id, "email", "foo@example.com"))
+
+    def test_on_remove_user_third_party_identifier_is_called_on_deactivate(
+        self,
+    ) -> None:
+        """Tests that the on_remove_user_third_party_identifier module callback is called
+        when a user is deactivated and their third-party ID associations are deleted.
+        """
+        # Pretend to be a Synapse module and register both callbacks as mocks.
+        third_party_rules = self.hs.get_third_party_event_rules()
+        on_remove_user_third_party_identifier_callback_mock = Mock(
+            return_value=make_awaitable(None)
+        )
+        third_party_rules._on_threepid_bind_callbacks.append(
+            on_remove_user_third_party_identifier_callback_mock
+        )
+
+        # Register an admin user.
+        self.register_user("admin", "password", admin=True)
+        admin_tok = self.login("admin", "password")
+
+        # Also register a normal user we can modify.
+        user_id = self.register_user("user", "password")
+
+        # Add a 3PID to the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {
+                "threepids": [
+                    {
+                        "medium": "email",
+                        "address": "foo@example.com",
+                    },
+                ],
+            },
+            access_token=admin_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Now deactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % user_id,
+            {
+                "deactivated": True,
+            },
+            access_token=admin_tok,
+        )
+
+        # Check that the mocked remove callback was called with the appropriate
+        # 3PID details.
+        self.assertEqual(channel.code, 200, channel.json_body)
+        on_remove_user_third_party_identifier_callback_mock.assert_called_once()
+        args = on_remove_user_third_party_identifier_callback_mock.call_args[0]
+        self.assertEqual(args, (user_id, "email", "foo@example.com"))
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 3086e1b565..d8dc56261a 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         self.cache = HttpTransactionCache(self.hs)
 
         self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
-        self.mock_key = "foo"
+
+        # Here we make sure that we're setting all the fields that HttpTransactionCache
+        # uses to build the transaction key.
+        self.mock_request = Mock()
+        self.mock_request.path = b"/foo/bar"
+        self.mock_requester = Mock()
+        self.mock_requester.app_service = None
+        self.mock_requester.is_guest = False
+        self.mock_requester.access_token_id = 1234
 
     @defer.inlineCallbacks
     def test_executes_given_function(
         self,
     ) -> Generator["defer.Deferred[Any]", object, None]:
         cb = Mock(return_value=make_awaitable(self.mock_http_response))
-        res = yield self.cache.fetch_or_execute(
-            self.mock_key, cb, "some_arg", keyword="arg"
+        res = yield self.cache.fetch_or_execute_request(
+            self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
         )
         cb.assert_called_once_with("some_arg", keyword="arg")
         self.assertEqual(res, self.mock_http_response)
@@ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
     ) -> Generator["defer.Deferred[Any]", object, None]:
         cb = Mock(return_value=make_awaitable(self.mock_http_response))
         for i in range(3):  # invoke multiple times
-            res = yield self.cache.fetch_or_execute(
-                self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
+            res = yield self.cache.fetch_or_execute_request(
+                self.mock_request,
+                self.mock_requester,
+                cb,
+                "some_arg",
+                keyword="arg",
+                changing_args=i,
             )
             self.assertEqual(res, self.mock_http_response)
         # expect only a single call to do the work
@@ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         @defer.inlineCallbacks
         def test() -> Generator["defer.Deferred[Any]", object, None]:
             with LoggingContext("c") as c1:
-                res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+                res = yield self.cache.fetch_or_execute_request(
+                    self.mock_request, self.mock_requester, cb
+                )
                 self.assertIs(current_context(), c1)
                 self.assertEqual(res, (1, {}))
 
@@ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
 
         with LoggingContext("test") as test_context:
             try:
-                yield self.cache.fetch_or_execute(self.mock_key, cb)
+                yield self.cache.fetch_or_execute_request(
+                    self.mock_request, self.mock_requester, cb
+                )
             except Exception as e:
                 self.assertEqual(e.args[0], "boo")
             self.assertIs(current_context(), test_context)
 
-            res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+            res = yield self.cache.fetch_or_execute_request(
+                self.mock_request, self.mock_requester, cb
+            )
             self.assertEqual(res, self.mock_http_response)
             self.assertIs(current_context(), test_context)
 
@@ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
 
         with LoggingContext("test") as test_context:
             try:
-                yield self.cache.fetch_or_execute(self.mock_key, cb)
+                yield self.cache.fetch_or_execute_request(
+                    self.mock_request, self.mock_requester, cb
+                )
             except Exception as e:
                 self.assertEqual(e.args[0], "boo")
             self.assertIs(current_context(), test_context)
 
-            res = yield self.cache.fetch_or_execute(self.mock_key, cb)
+            res = yield self.cache.fetch_or_execute_request(
+                self.mock_request, self.mock_requester, cb
+            )
             self.assertEqual(res, self.mock_http_response)
             self.assertIs(current_context(), test_context)
 
     @defer.inlineCallbacks
     def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
         cb = Mock(return_value=make_awaitable(self.mock_http_response))
-        yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+        yield self.cache.fetch_or_execute_request(
+            self.mock_request, self.mock_requester, cb, "an arg"
+        )
         # should NOT have cleaned up yet
         self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
 
-        yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+        yield self.cache.fetch_or_execute_request(
+            self.mock_request, self.mock_requester, cb, "an arg"
+        )
         # still using cache
         cb.assert_called_once_with("an arg")
 
         self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
 
-        yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
+        yield self.cache.fetch_or_execute_request(
+            self.mock_request, self.mock_requester, cb, "an arg"
+        )
         # no longer using cache
         self.assertEqual(cb.call_count, 2)
         self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 5ec343dd7f..0b4c691318 100644
--- a/tests/rest/client/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
@@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
                 self.room_id, EventTypes.Tombstone, ""
             )
         )
-        self.assertIsNotNone(tombstone_event)
+        assert tombstone_event is not None
         self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
 
         # Check that the new room exists.
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8d6f2b6ff9..9532e5ddc1 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -36,6 +36,7 @@ from urllib.parse import urlencode
 import attr
 from typing_extensions import Literal
 
+from twisted.test.proto_helpers import MemoryReactorClock
 from twisted.web.resource import Resource
 from twisted.web.server import Site
 
@@ -67,6 +68,7 @@ class RestHelper:
     """
 
     hs: HomeServer
+    reactor: MemoryReactorClock
     site: Site
     auth_user_id: Optional[str]
 
@@ -142,7 +144,7 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             path,
@@ -216,7 +218,7 @@ class RestHelper:
             data["reason"] = reason
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             path,
@@ -313,7 +315,7 @@ class RestHelper:
         data.update(extra_data or {})
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "PUT",
             path,
@@ -394,7 +396,7 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "PUT",
             path,
@@ -433,7 +435,7 @@ class RestHelper:
             path = path + f"?access_token={tok}"
 
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             path,
@@ -488,7 +490,7 @@ class RestHelper:
         if body is not None:
             content = json.dumps(body).encode("utf8")
 
-        channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
+        channel = make_request(self.reactor, self.site, method, path, content)
 
         assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
             expect_code,
@@ -573,8 +575,8 @@ class RestHelper:
         image_length = len(image_data)
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
         channel = make_request(
-            self.hs.get_reactor(),
-            FakeSite(resource, self.hs.get_reactor()),
+            self.reactor,
+            FakeSite(resource, self.reactor),
             "POST",
             path,
             content=image_data,
@@ -603,7 +605,7 @@ class RestHelper:
             expect_code: The return code to expect from attempting the whoami request
         """
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             "account/whoami",
@@ -642,7 +644,7 @@ class RestHelper:
     ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
         """Log in (as a new user) via OIDC
 
-        Returns the result of the final token login.
+        Returns the result of the final token login and the fake authorization grant.
 
         Requires that "oidc_config" in the homeserver config be set appropriately
         (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
@@ -672,10 +674,28 @@ class RestHelper:
         assert m, channel.text_body
         login_token = m.group(1)
 
-        # finally, submit the matrix login token to the login API, which gives us our
-        # matrix access token and device id.
+        return self.login_via_token(login_token, expected_status), grant
+
+    def login_via_token(
+        self,
+        login_token: str,
+        expected_status: int = 200,
+    ) -> JsonDict:
+        """Submit the matrix login token to the login API, which gives us our
+        matrix access token and device id.Log in (as a new user) via OIDC
+
+        Returns the result of the token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "POST",
             "/login",
@@ -684,7 +704,7 @@ class RestHelper:
         assert (
             channel.code == expected_status
         ), f"unexpected status in response: {channel.code}"
-        return channel.json_body, grant
+        return channel.json_body
 
     def auth_via_oidc(
         self,
@@ -805,7 +825,7 @@ class RestHelper:
         with fake_serer.patch_homeserver(hs=self.hs):
             # now hit the callback URI with the right params and a made-up code
             channel = make_request(
-                self.hs.get_reactor(),
+                self.reactor,
                 self.site,
                 "GET",
                 callback_uri,
@@ -849,7 +869,7 @@ class RestHelper:
         # is the easiest way of figuring out what the Host header ought to be set to
         # to keep Synapse happy.
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             uri,
@@ -867,7 +887,7 @@ class RestHelper:
         location = get_location(channel)
         parts = urllib.parse.urlsplit(location)
         channel = make_request(
-            self.hs.get_reactor(),
+            self.reactor,
             self.site,
             "GET",
             urllib.parse.urlunsplit(("", "") + parts[2:]),
@@ -900,9 +920,7 @@ class RestHelper:
             + urllib.parse.urlencode({"session": ui_auth_session_id})
         )
         # hit the redirect url (which will issue a cookie and state)
-        channel = make_request(
-            self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
-        )
+        channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
         # that should serve a confirmation page
         assert channel.code == HTTPStatus.OK, channel.text_body
         channel.extract_cookies(cookies)
diff --git a/tests/rest/media/test_media_retention.py b/tests/rest/media/test_media_retention.py
index 23f227aed6..b59d9dfd4d 100644
--- a/tests/rest/media/test_media_retention.py
+++ b/tests/rest/media/test_media_retention.py
@@ -31,7 +31,6 @@ from tests.utils import MockClock
 
 
 class MediaRetentionTestCase(unittest.HomeserverTestCase):
-
     ONE_DAY_IN_MS = 24 * 60 * 60 * 1000
     THIRTY_DAYS_IN_MS = 30 * ONE_DAY_IN_MS
 
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/test_url_preview.py
index 2c321f8d04..e91dc581c2 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/test_url_preview.py
@@ -26,8 +26,8 @@ from twisted.internet.interfaces import IAddress, IResolutionReceiver
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
 
 from synapse.config.oembed import OEmbedEndpointConfig
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
-from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
+from synapse.rest.media.media_repository_resource import MediaRepositoryResource
+from synapse.rest.media.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -58,7 +58,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
     )
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-
         config = self.default_config()
         config["url_preview_enabled"] = True
         config["max_spider_size"] = 9999999
@@ -83,7 +82,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         config["media_store_path"] = self.media_store_path
 
         provider_config = {
-            "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
+            "module": "synapse.media.storage_provider.FileStorageProviderBackend",
             "store_local": True,
             "store_synchronous": False,
             "store_remote": True,
@@ -118,7 +117,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-
         self.media_repo = hs.get_media_repository_resource()
         self.preview_url = self.media_repo.children[b"preview_url"]
 
@@ -133,7 +131,6 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 addressTypes: Optional[Sequence[Type[IAddress]]] = None,
                 transportSemantics: str = "TCP",
             ) -> IResolutionReceiver:
-
                 resolution = HostResolution(hostName)
                 resolutionReceiver.resolutionBegan(resolution)
                 if hostName not in self.lookups:
@@ -660,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """If the preview image doesn't exist, ensure some data is returned."""
         self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
-        end_content = (
+        result = (
             b"""<html><body><img src="http://cdn.matrix.org/foo.jpg"></body></html>"""
         )
 
@@ -681,8 +678,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
                 b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
                 b'Content-Type: text/html; charset="utf8"\r\n\r\n'
             )
-            % (len(end_content),)
-            + end_content
+            % (len(result),)
+            + result
         )
 
         self.pump()
@@ -691,6 +688,44 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         # The image should not be in the result.
         self.assertNotIn("og:image", channel.json_body)
 
+    def test_oembed_failure(self) -> None:
+        """If the autodiscovered oEmbed URL fails, ensure some data is returned."""
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
+
+        result = b"""
+        <title>oEmbed Autodiscovery Fail</title>
+        <link rel="alternate" type="application/json+oembed"
+            href="http://example.com/oembed?url=http%3A%2F%2Fmatrix.org&format=json"
+            title="matrixdotorg" />
+        """
+
+        channel = self.make_request(
+            "GET",
+            "preview_url?url=http://matrix.org",
+            shorthand=False,
+            await_result=False,
+        )
+        self.pump()
+
+        client = self.reactor.tcpClients[0][2].buildProtocol(None)
+        server = AccumulatingProtocol()
+        server.makeConnection(FakeTransport(client, self.reactor))
+        client.makeConnection(FakeTransport(server, self.reactor))
+        client.dataReceived(
+            (
+                b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+            )
+            % (len(result),)
+            + result
+        )
+
+        self.pump()
+        self.assertEqual(channel.code, 200)
+
+        # The image should not be in the result.
+        self.assertEqual(channel.json_body["og:title"], "oEmbed Autodiscovery Fail")
+
     def test_data_url(self) -> None:
         """
         Requesting to preview a data URL is not supported.
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 22f99c6ab1..3285f2433c 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,29 +12,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 from unittest.mock import Mock, patch
 
 from synapse._scripts.register_new_matrix_user import request_registration
+from synapse.types import JsonDict
 
 from tests.unittest import TestCase
 
 
 class RegisterTestCase(TestCase):
-    def test_success(self):
+    def test_success(self) -> None:
         """
         The script will fetch a nonce, and then generate a MAC with it, and then
         post that MAC.
         """
 
-        def get(url, verify=None):
+        def get(url: str, verify: Optional[bool] = None) -> Mock:
             r = Mock()
             r.status_code = 200
             r.json = lambda: {"nonce": "a"}
             return r
 
-        def post(url, json=None, verify=None):
+        def post(
+            url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+        ) -> Mock:
             # Make sure we are sent the correct info
+            assert json is not None
             self.assertEqual(json["username"], "user")
             self.assertEqual(json["password"], "pass")
             self.assertEqual(json["nonce"], "a")
@@ -70,12 +74,12 @@ class RegisterTestCase(TestCase):
         # sys.exit shouldn't have been called.
         self.assertEqual(err_code, [])
 
-    def test_failure_nonce(self):
+    def test_failure_nonce(self) -> None:
         """
         If the script fails to fetch a nonce, it throws an error and quits.
         """
 
-        def get(url, verify=None):
+        def get(url: str, verify: Optional[bool] = None) -> Mock:
             r = Mock()
             r.status_code = 404
             r.reason = "Not Found"
@@ -107,20 +111,23 @@ class RegisterTestCase(TestCase):
         self.assertIn("ERROR! Received 404 Not Found", out)
         self.assertNotIn("Success!", out)
 
-    def test_failure_post(self):
+    def test_failure_post(self) -> None:
         """
         The script will fetch a nonce, and then if the final POST fails, will
         report an error and quit.
         """
 
-        def get(url, verify=None):
+        def get(url: str, verify: Optional[bool] = None) -> Mock:
             r = Mock()
             r.status_code = 200
             r.json = lambda: {"nonce": "a"}
             return r
 
-        def post(url, json=None, verify=None):
+        def post(
+            url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+        ) -> Mock:
             # Make sure we are sent the correct info
+            assert json is not None
             self.assertEqual(json["username"], "user")
             self.assertEqual(json["password"], "pass")
             self.assertEqual(json["nonce"], "a")
diff --git a/tests/server.py b/tests/server.py
index b1730fcc8d..5de9722766 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,20 +22,25 @@ import warnings
 from collections import deque
 from io import SEEK_END, BytesIO
 from typing import (
+    Any,
+    Awaitable,
     Callable,
     Dict,
     Iterable,
     List,
     MutableMapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
+    TypeVar,
     Union,
+    cast,
 )
 from unittest.mock import Mock
 
 import attr
-from typing_extensions import Deque
+from typing_extensions import Deque, ParamSpec
 from zope.interface import implementer
 
 from twisted.internet import address, threads, udp
@@ -44,8 +49,10 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import (
     IAddress,
+    IConnector,
     IConsumer,
     IHostnameResolver,
+    IProducer,
     IProtocol,
     IPullProducer,
     IPushProducer,
@@ -54,6 +61,8 @@ from twisted.internet.interfaces import (
     IResolverSimple,
     ITransport,
 )
+from twisted.internet.protocol import ClientFactory, DatagramProtocol
+from twisted.python import threadpool
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http_headers import Headers
@@ -61,6 +70,7 @@ from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
 from synapse.config.database import DatabaseConnectionConfig
+from synapse.config.homeserver import HomeServerConfig
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -70,7 +80,7 @@ from synapse.logging.context import ContextResourceUsage
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import PostgresEngine, create_engine
-from synapse.types import JsonDict
+from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import Clock
 
 from tests.utils import (
@@ -88,6 +98,9 @@ from tests.utils import (
 
 logger = logging.getLogger(__name__)
 
+R = TypeVar("R")
+P = ParamSpec("P")
+
 # the type of thing that can be passed into `make_request` in the headers list
 CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
 
@@ -98,12 +111,14 @@ class TimedOutException(Exception):
     """
 
 
-@implementer(IConsumer)
+@implementer(ITransport, IPushProducer, IConsumer)
 @attr.s(auto_attribs=True)
 class FakeChannel:
     """
     A fake Twisted Web Channel (the part that interfaces with the
     wire).
+
+    See twisted.web.http.HTTPChannel.
     """
 
     site: Union[Site, "FakeSite"]
@@ -142,7 +157,7 @@ class FakeChannel:
 
         Raises an exception if the request has not yet completed.
         """
-        if not self.is_finished:
+        if not self.is_finished():
             raise Exception("Request not yet completed")
         return self.result["body"].decode("utf8")
 
@@ -165,27 +180,36 @@ class FakeChannel:
             h.addRawHeader(*i)
         return h
 
-    def writeHeaders(self, version, code, reason, headers):
+    def writeHeaders(
+        self, version: bytes, code: bytes, reason: bytes, headers: Headers
+    ) -> None:
         self.result["version"] = version
         self.result["code"] = code
         self.result["reason"] = reason
         self.result["headers"] = headers
 
-    def write(self, content: bytes) -> None:
-        assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+    def write(self, data: bytes) -> None:
+        assert isinstance(data, bytes), "Should be bytes! " + repr(data)
 
         if "body" not in self.result:
             self.result["body"] = b""
 
-        self.result["body"] += content
+        self.result["body"] += data
+
+    def writeSequence(self, data: Iterable[bytes]) -> None:
+        for x in data:
+            self.write(x)
+
+    def loseConnection(self) -> None:
+        self.unregisterProducer()
+        self.transport.loseConnection()
 
     # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
-    def registerProducer(  # type: ignore[override]
-        self,
-        producer: Union[IPullProducer, IPushProducer],
-        streaming: bool,
-    ) -> None:
-        self._producer = producer
+    def registerProducer(self, producer: IProducer, streaming: bool) -> None:
+        # TODO This should ensure that the IProducer is an IPushProducer or
+        # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
+        # implement those, but doesn't declare it.
+        self._producer = cast(Union[IPushProducer, IPullProducer], producer)
         self.producerStreaming = streaming
 
         def _produce() -> None:
@@ -202,6 +226,16 @@ class FakeChannel:
 
         self._producer = None
 
+    def stopProducing(self) -> None:
+        if self._producer is not None:
+            self._producer.stopProducing()
+
+    def pauseProducing(self) -> None:
+        raise NotImplementedError()
+
+    def resumeProducing(self) -> None:
+        raise NotImplementedError()
+
     def requestDone(self, _self: Request) -> None:
         self.result["done"] = True
         if isinstance(_self, SynapseRequest):
@@ -281,12 +315,12 @@ class FakeSite:
         self.reactor = reactor
         self.experimental_cors_msc3886 = experimental_cors_msc3886
 
-    def getResourceFor(self, request):
+    def getResourceFor(self, request: Request) -> IResource:
         return self._resource
 
 
 def make_request(
-    reactor,
+    reactor: MemoryReactorClock,
     site: Union[Site, FakeSite],
     method: Union[bytes, str],
     path: Union[bytes, str],
@@ -401,25 +435,29 @@ def make_request(
     return channel
 
 
-@implementer(IReactorPluggableNameResolver)
+# ISynapseReactor implies IReactorPluggableNameResolver, but explicitly
+# marking this as an implementer of the latter seems to keep mypy-zope happier.
+@implementer(IReactorPluggableNameResolver, ISynapseReactor)
 class ThreadedMemoryReactorClock(MemoryReactorClock):
     """
     A MemoryReactorClock that supports callFromThread.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.threadpool = ThreadPool(self)
 
         self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
-        self._udp = []
+        self._udp: List[udp.Port] = []
         self.lookups: Dict[str, str] = {}
-        self._thread_callbacks: Deque[Callable[[], None]] = deque()
+        self._thread_callbacks: Deque[Callable[..., R]] = deque()
 
         lookups = self.lookups
 
         @implementer(IResolverSimple)
         class FakeResolver:
-            def getHostByName(self, name, timeout=None):
+            def getHostByName(
+                self, name: str, timeout: Optional[Sequence[int]] = None
+            ) -> "Deferred[str]":
                 if name not in lookups:
                     return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                 return succeed(lookups[name])
@@ -430,25 +468,44 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
         raise NotImplementedError()
 
-    def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
+    def listenUDP(
+        self,
+        port: int,
+        protocol: DatagramProtocol,
+        interface: str = "",
+        maxPacketSize: int = 8196,
+    ) -> udp.Port:
         p = udp.Port(port, protocol, interface, maxPacketSize, self)
         p.startListening()
         self._udp.append(p)
         return p
 
-    def callFromThread(self, callback, *args, **kwargs):
+    def callFromThread(
+        self, callable: Callable[..., Any], *args: object, **kwargs: object
+    ) -> None:
         """
         Make the callback fire in the next reactor iteration.
         """
-        cb = lambda: callback(*args, **kwargs)
+        cb = lambda: callable(*args, **kwargs)
         # it's not safe to call callLater() here, so we append the callback to a
         # separate queue.
         self._thread_callbacks.append(cb)
 
-    def getThreadPool(self):
-        return self.threadpool
+    def callInThread(
+        self, callable: Callable[..., Any], *args: object, **kwargs: object
+    ) -> None:
+        raise NotImplementedError()
+
+    def suggestThreadPoolSize(self, size: int) -> None:
+        raise NotImplementedError()
+
+    def getThreadPool(self) -> "threadpool.ThreadPool":
+        # Cast to match super-class.
+        return cast(threadpool.ThreadPool, self.threadpool)
 
-    def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
+    def add_tcp_client_callback(
+        self, host: str, port: int, callback: Callable[[], None]
+    ) -> None:
         """Add a callback that will be invoked when we receive a connection
         attempt to the given IP/port using `connectTCP`.
 
@@ -457,7 +514,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         """
         self._tcp_callbacks[(host, port)] = callback
 
-    def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
+    def connectTCP(
+        self,
+        host: str,
+        port: int,
+        factory: ClientFactory,
+        timeout: float = 30,
+        bindAddress: Optional[Tuple[str, int]] = None,
+    ) -> IConnector:
         """Fake L{IReactorTCP.connectTCP}."""
 
         conn = super().connectTCP(
@@ -470,7 +534,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         return conn
 
-    def advance(self, amount):
+    def advance(self, amount: float) -> None:
         # first advance our reactor's time, and run any "callLater" callbacks that
         # makes ready
         super().advance(amount)
@@ -498,25 +562,33 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 class ThreadPool:
     """
     Threadless thread pool.
+
+    See twisted.python.threadpool.ThreadPool
     """
 
-    def __init__(self, reactor):
+    def __init__(self, reactor: IReactorTime):
         self._reactor = reactor
 
-    def start(self):
+    def start(self) -> None:
         pass
 
-    def stop(self):
+    def stop(self) -> None:
         pass
 
-    def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
-        def _(res):
+    def callInThreadWithCallback(
+        self,
+        onResult: Callable[[bool, Union[Failure, R]], None],
+        function: Callable[P, R],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> "Deferred[None]":
+        def _(res: Any) -> None:
             if isinstance(res, Failure):
                 onResult(False, res)
             else:
                 onResult(True, res)
 
-        d = Deferred()
+        d: "Deferred[None]" = Deferred()
         d.addCallback(lambda x: function(*args, **kwargs))
         d.addBoth(_)
         self._reactor.callLater(0, d.callback, True)
@@ -533,7 +605,9 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
     for database in server.get_datastores().databases:
         pool = database._db_pool
 
-        def runWithConnection(func, *args, **kwargs):
+        def runWithConnection(
+            func: Callable[..., R], *args: Any, **kwargs: Any
+        ) -> Awaitable[R]:
             return threads.deferToThreadPool(
                 pool._reactor,
                 pool.threadpool,
@@ -543,20 +617,23 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
                 **kwargs,
             )
 
-        def runInteraction(interaction, *args, **kwargs):
+        def runInteraction(
+            desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+        ) -> Awaitable[R]:
             return threads.deferToThreadPool(
                 pool._reactor,
                 pool.threadpool,
                 pool._runInteraction,
-                interaction,
+                desc,
+                func,
                 *args,
                 **kwargs,
             )
 
-        pool.runWithConnection = runWithConnection
-        pool.runInteraction = runInteraction
+        pool.runWithConnection = runWithConnection  # type: ignore[assignment]
+        pool.runInteraction = runInteraction  # type: ignore[assignment]
         # Replace the thread pool with a threadless 'thread' pool
-        pool.threadpool = ThreadPool(clock._reactor)
+        pool.threadpool = ThreadPool(clock._reactor)  # type: ignore[assignment]
         pool.running = True
 
     # We've just changed the Databases to run DB transactions on the same
@@ -571,7 +648,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
 
 
 @implementer(ITransport)
-@attr.s(cmp=False)
+@attr.s(cmp=False, auto_attribs=True)
 class FakeTransport:
     """
     A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -586,48 +663,50 @@ class FakeTransport:
     If you want bidirectional communication, you'll need two instances.
     """
 
-    other = attr.ib()
+    other: IProtocol
     """The Protocol object which will receive any data written to this transport.
-
-    :type: twisted.internet.interfaces.IProtocol
     """
 
-    _reactor = attr.ib()
+    _reactor: IReactorTime
     """Test reactor
-
-    :type: twisted.internet.interfaces.IReactorTime
     """
 
-    _protocol = attr.ib(default=None)
+    _protocol: Optional[IProtocol] = None
     """The Protocol which is producing data for this transport. Optional, but if set
     will get called back for connectionLost() notifications etc.
     """
 
-    _peer_address: Optional[IAddress] = attr.ib(default=None)
+    _peer_address: IAddress = attr.Factory(
+        lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
+    )
     """The value to be returned by getPeer"""
 
-    _host_address: Optional[IAddress] = attr.ib(default=None)
+    _host_address: IAddress = attr.Factory(
+        lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
+    )
     """The value to be returned by getHost"""
 
     disconnecting = False
     disconnected = False
     connected = True
-    buffer = attr.ib(default=b"")
-    producer = attr.ib(default=None)
-    autoflush = attr.ib(default=True)
+    buffer: bytes = b""
+    producer: Optional[IPushProducer] = None
+    autoflush: bool = True
 
-    def getPeer(self) -> Optional[IAddress]:
+    def getPeer(self) -> IAddress:
         return self._peer_address
 
-    def getHost(self) -> Optional[IAddress]:
+    def getHost(self) -> IAddress:
         return self._host_address
 
-    def loseConnection(self, reason=None):
+    def loseConnection(self) -> None:
         if not self.disconnecting:
-            logger.info("FakeTransport: loseConnection(%s)", reason)
+            logger.info("FakeTransport: loseConnection()")
             self.disconnecting = True
             if self._protocol:
-                self._protocol.connectionLost(reason)
+                self._protocol.connectionLost(
+                    Failure(RuntimeError("FakeTransport.loseConnection()"))
+                )
 
             # if we still have data to write, delay until that is done
             if self.buffer:
@@ -638,38 +717,38 @@ class FakeTransport:
                 self.connected = False
                 self.disconnected = True
 
-    def abortConnection(self):
+    def abortConnection(self) -> None:
         logger.info("FakeTransport: abortConnection()")
 
         if not self.disconnecting:
             self.disconnecting = True
             if self._protocol:
-                self._protocol.connectionLost(None)
+                self._protocol.connectionLost(None)  # type: ignore[arg-type]
 
         self.disconnected = True
 
-    def pauseProducing(self):
+    def pauseProducing(self) -> None:
         if not self.producer:
             return
 
         self.producer.pauseProducing()
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         if not self.producer:
             return
         self.producer.resumeProducing()
 
-    def unregisterProducer(self):
+    def unregisterProducer(self) -> None:
         if not self.producer:
             return
 
         self.producer = None
 
-    def registerProducer(self, producer, streaming):
+    def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
         self.producer = producer
         self.producerStreaming = streaming
 
-        def _produce():
+        def _produce() -> None:
             if not self.producer:
                 # we've been unregistered
                 return
@@ -681,7 +760,7 @@ class FakeTransport:
         if not streaming:
             self._reactor.callLater(0.0, _produce)
 
-    def write(self, byt):
+    def write(self, byt: bytes) -> None:
         if self.disconnecting:
             raise Exception("Writing to disconnecting FakeTransport")
 
@@ -693,11 +772,11 @@ class FakeTransport:
         if self.autoflush:
             self._reactor.callLater(0.0, self.flush)
 
-    def writeSequence(self, seq):
+    def writeSequence(self, seq: Iterable[bytes]) -> None:
         for x in seq:
             self.write(x)
 
-    def flush(self, maxbytes=None):
+    def flush(self, maxbytes: Optional[int] = None) -> None:
         if not self.buffer:
             # nothing to do. Don't write empty buffers: it upsets the
             # TLSMemoryBIOProtocol
@@ -748,17 +827,17 @@ def connect_client(
 
 
 class TestHomeServer(HomeServer):
-    DATASTORE_CLASS = DataStore
+    DATASTORE_CLASS = DataStore  # type: ignore[assignment]
 
 
 def setup_test_homeserver(
-    cleanup_func,
-    name="test",
-    config=None,
-    reactor=None,
+    cleanup_func: Callable[[Callable[[], None]], None],
+    name: str = "test",
+    config: Optional[HomeServerConfig] = None,
+    reactor: Optional[ISynapseReactor] = None,
     homeserver_to_use: Type[HomeServer] = TestHomeServer,
-    **kwargs,
-):
+    **kwargs: Any,
+) -> HomeServer:
     """
     Setup a homeserver suitable for running tests against.  Keyword arguments
     are passed to the Homeserver constructor.
@@ -773,13 +852,14 @@ def setup_test_homeserver(
     HomeserverTestCase.
     """
     if reactor is None:
-        from twisted.internet import reactor
+        from twisted.internet import reactor as _reactor
+
+        reactor = cast(ISynapseReactor, _reactor)
 
     if config is None:
         config = default_config(name, parse=True)
 
     config.caches.resize_all_caches()
-    config.ldap_enabled = False
 
     if "clock" not in kwargs:
         kwargs["clock"] = MockClock()
@@ -830,6 +910,8 @@ def setup_test_homeserver(
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
     if isinstance(db_engine, PostgresEngine):
+        import psycopg2.extensions
+
         db_conn = db_engine.module.connect(
             database=POSTGRES_BASE_DB,
             user=POSTGRES_USER,
@@ -837,6 +919,7 @@ def setup_test_homeserver(
             port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
         )
+        assert isinstance(db_conn, psycopg2.extensions.connection)
         db_conn.autocommit = True
         cur = db_conn.cursor()
         cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
@@ -865,14 +948,15 @@ def setup_test_homeserver(
         hs.setup_background_tasks()
 
     if isinstance(db_engine, PostgresEngine):
-        database = hs.get_datastores().databases[0]
+        database_pool = hs.get_datastores().databases[0]
 
         # We need to do cleanup on PostgreSQL
-        def cleanup():
+        def cleanup() -> None:
             import psycopg2
+            import psycopg2.extensions
 
             # Close all the db pools
-            database._db_pool.close()
+            database_pool._db_pool.close()
 
             dropped = False
 
@@ -884,6 +968,7 @@ def setup_test_homeserver(
                 port=POSTGRES_PORT,
                 password=POSTGRES_PASSWORD,
             )
+            assert isinstance(db_conn, psycopg2.extensions.connection)
             db_conn.autocommit = True
             cur = db_conn.cursor()
 
@@ -916,23 +1001,23 @@ def setup_test_homeserver(
     # Need to let the HS build an auth handler and then mess with it
     # because AuthHandler's constructor requires the HS, so we can't make one
     # beforehand and pass it in to the HS's constructor (chicken / egg)
-    async def hash(p):
+    async def hash(p: str) -> str:
         return hashlib.md5(p.encode("utf8")).hexdigest()
 
-    hs.get_auth_handler().hash = hash
+    hs.get_auth_handler().hash = hash  # type: ignore[assignment]
 
-    async def validate_hash(p, h):
+    async def validate_hash(p: str, h: str) -> bool:
         return hashlib.md5(p.encode("utf8")).hexdigest() == h
 
-    hs.get_auth_handler().validate_hash = validate_hash
+    hs.get_auth_handler().validate_hash = validate_hash  # type: ignore[assignment]
 
     # Make the threadpool and database transactions synchronous for testing.
     _make_test_homeserver_synchronous(hs)
 
     # Load any configured modules into the homeserver
     module_api = hs.get_module_api()
-    for module, config in hs.config.modules.loaded_modules:
-        module(config=config, api=module_api)
+    for module, module_config in hs.config.modules.loaded_modules:
+        module(config=module_config, api=module_api)
 
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 58b399a043..3fdf5a6d52 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -14,14 +14,17 @@
 
 import os
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
 
 class ConsentNoticesTests(unittest.HomeserverTestCase):
-
     servlets = [
         sync.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
@@ -29,8 +32,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
-
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         tmpdir = self.mktemp()
         os.mkdir(tmpdir)
         self.consent_notice_message = "consent %(consent_uri)s"
@@ -53,15 +55,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
             "room_name": "Server Notices",
         }
 
-        hs = self.setup_test_homeserver(config=config)
-
-        return hs
+        return self.setup_test_homeserver(config=config)
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("bob", "abc123")
         self.access_token = self.login("bob", "abc123")
 
-    def test_get_sync_message(self):
+    def test_get_sync_message(self) -> None:
         """
         When user consent server notices are enabled, a sync will cause a notice
         to fire (in a room which the user is invited to). The notice contains
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index dadc6efcbf..d2bfa53eda 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -24,6 +24,8 @@ from synapse.server import HomeServer
 from synapse.server_notices.resource_limits_server_notices import (
     ResourceLimitsServerNotices,
 )
+from synapse.server_notices.server_notices_sender import ServerNoticesSender
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -33,7 +35,7 @@ from tests.utils import default_config
 
 
 class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = default_config("test")
 
         config.update(
@@ -57,14 +59,15 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         return config
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.server_notices_sender = self.hs.get_server_notices_sender()
+        server_notices_sender = self.hs.get_server_notices_sender()
+        assert isinstance(server_notices_sender, ServerNoticesSender)
 
         # relying on [1] is far from ideal, but the only case where
         # ResourceLimitsServerNotices class needs to be isolated is this test,
         # general code should never have a reason to do so ...
-        self._rlsn = self.server_notices_sender._server_notices[1]
-        if not isinstance(self._rlsn, ResourceLimitsServerNotices):
-            raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+        rlsn = list(server_notices_sender._server_notices)[1]
+        assert isinstance(rlsn, ResourceLimitsServerNotices)
+        self._rlsn = rlsn
 
         self._rlsn._store.user_last_seen_monthly_active = Mock(
             return_value=make_awaitable(1000)
@@ -86,39 +89,43 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))  # type: ignore[assignment]
 
     @override_config({"hs_disabled": True})
-    def test_maybe_send_server_notice_disabled_hs(self):
+    def test_maybe_send_server_notice_disabled_hs(self) -> None:
         """If the HS is disabled, we should not send notices"""
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         self._send_notice.assert_not_called()
 
     @override_config({"limit_usage_by_mau": False})
-    def test_maybe_send_server_notice_to_user_flag_off(self):
+    def test_maybe_send_server_notice_to_user_flag_off(self) -> None:
         """If mau limiting is disabled, we should not send notices"""
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         self._send_notice.assert_not_called()
 
-    def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
+    def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
         """Test when user has blocked notice, but should have it removed"""
 
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None)
         )
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(
+        self._rlsn._store.get_events = Mock(  # type: ignore[assignment]
             return_value=make_awaitable({"123": mock_event})
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         # Would be better to check the content, but once == remove blocking event
-        self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
+        maybe_get_notice_room_for_user = (
+            self._rlsn._server_notices_manager.maybe_get_notice_room_for_user
+        )
+        assert isinstance(maybe_get_notice_room_for_user, Mock)
+        maybe_get_notice_room_for_user.assert_called_once()
         self._send_notice.assert_called_once()
 
-    def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
+    def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
         """
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(403, "foo"),
         )
@@ -126,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(
+        self._rlsn._store.get_events = Mock(  # type: ignore[assignment]
             return_value=make_awaitable({"123": mock_event})
         )
 
@@ -134,11 +141,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
 
         self._send_notice.assert_not_called()
 
-    def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
+    def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None:
         """
         Test when user does not have blocked notice, but should have one
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(403, "foo"),
         )
@@ -147,11 +154,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         # Would be better to check contents, but 2 calls == set blocking event
         self.assertEqual(self._send_notice.call_count, 2)
 
-    def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
+    def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None:
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None)
         )
 
@@ -159,12 +166,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
 
         self._send_notice.assert_not_called()
 
-    def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
+    def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None:
         """
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None)
         )
         self._rlsn._store.user_last_seen_monthly_active = Mock(
@@ -175,12 +182,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._send_notice.assert_not_called()
 
     @override_config({"mau_limit_alerting": False})
-    def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
+    def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(
+        self,
+    ) -> None:
         """
         Test that when server is over MAU limit and alerting is suppressed, then
         an alert message is not sent into the room
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
@@ -191,11 +200,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self.assertEqual(self._send_notice.call_count, 0)
 
     @override_config({"mau_limit_alerting": False})
-    def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
+    def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None:
         """
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
@@ -207,26 +216,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self.assertEqual(self._send_notice.call_count, 2)
 
     @override_config({"mau_limit_alerting": False})
-    def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
+    def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(
+        self,
+    ) -> None:
         """
         When the room is already in a blocked state, test that when alerting
         is suppressed that the room is returned to an unblocked state.
         """
-        self._rlsn._auth_blocking.check_auth_blocking = Mock(
+        self._rlsn._auth_blocking.check_auth_blocking = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None),
             side_effect=ResourceLimitError(
                 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
             ),
         )
 
-        self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
+        self._rlsn._is_room_currently_blocked = Mock(  # type: ignore[assignment]
             return_value=make_awaitable((True, []))
         )
 
         mock_event = Mock(
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
-        self._rlsn._store.get_events = Mock(
+        self._rlsn._store.get_events = Mock(  # type: ignore[assignment]
             return_value=make_awaitable({"123": mock_event})
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -242,7 +253,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         c = super().default_config()
         c["server_notices"] = {
             "system_mxid_localpart": "server",
@@ -257,20 +268,22 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = self.hs.get_datastores().main
-        self.server_notices_sender = self.hs.get_server_notices_sender()
         self.server_notices_manager = self.hs.get_server_notices_manager()
         self.event_source = self.hs.get_event_sources()
 
+        server_notices_sender = self.hs.get_server_notices_sender()
+        assert isinstance(server_notices_sender, ServerNoticesSender)
+
         # relying on [1] is far from ideal, but the only case where
         # ResourceLimitsServerNotices class needs to be isolated is this test,
         # general code should never have a reason to do so ...
-        self._rlsn = self.server_notices_sender._server_notices[1]
-        if not isinstance(self._rlsn, ResourceLimitsServerNotices):
-            raise Exception("Failed to find reference to ResourceLimitsServerNotices")
+        rlsn = list(server_notices_sender._server_notices)[1]
+        assert isinstance(rlsn, ResourceLimitsServerNotices)
+        self._rlsn = rlsn
 
         self.user_id = "@user_id:test"
 
-    def test_server_notice_only_sent_once(self):
+    def test_server_notice_only_sent_once(self) -> None:
         self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
 
         self.store.user_last_seen_monthly_active = Mock(
@@ -306,7 +319,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         self.assertEqual(count, 1)
 
-    def test_no_invite_without_notice(self):
+    def test_no_invite_without_notice(self) -> None:
         """Tests that a user doesn't get invited to a server notices room without a
         server notice being sent.
 
@@ -328,7 +341,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         m.assert_called_once_with(user_id)
 
-    def test_invite_with_notice(self):
+    def test_invite_with_notice(self) -> None:
         """Tests that, if the MAU limit is hit, the server notices user invites each user
         to a room in which it has sent a notice.
         """
diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py
index 373707b275..b6d5c474b0 100644
--- a/tests/storage/databases/main/test_deviceinbox.py
+++ b/tests/storage/databases/main/test_deviceinbox.py
@@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase
 
 
 class DeviceInboxBackgroundUpdateStoreTestCase(HomeserverTestCase):
-
     servlets = [
         admin.register_servlets,
         devices.register_servlets,
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 9f33afcca0..9606ecc43b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
         # Persist the event which should invalidate or prefill the
         # `have_seen_event` cache so we don't return stale values.
         persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
         self.get_success(
             persistence.persist_event(
                 event,
diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py
index ac77aec003..71db47405e 100644
--- a/tests/storage/databases/main/test_receipts.py
+++ b/tests/storage/databases/main/test_receipts.py
@@ -26,7 +26,6 @@ from tests.unittest import HomeserverTestCase
 
 
 class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
-
     servlets = [
         admin.register_servlets,
         room.register_servlets,
@@ -62,6 +61,7 @@ class ReceiptsBackgroundUpdateStoreTestCase(HomeserverTestCase):
                 keys and expected receipt key-values after duplicate receipts have been
                 removed.
         """
+
         # First, undo the background update.
         def drop_receipts_unique_index(txn: LoggingTransaction) -> None:
             txn.execute(f"DROP INDEX IF EXISTS {index_name}")
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 3108ca3444..dbd8f3a85e 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -27,7 +27,6 @@ from tests.unittest import HomeserverTestCase
 
 
 class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
-
     servlets = [
         admin.register_servlets,
         room.register_servlets,
diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py
index 1bfd11ceae..b12691a9d3 100644
--- a/tests/storage/test_account_data.py
+++ b/tests/storage/test_account_data.py
@@ -140,3 +140,25 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase):
         # No one ignores the user now.
         self.assert_ignored(self.user, set())
         self.assert_ignorers("@other:test", set())
+
+    def test_ignoring_users_with_latest_stream_ids(self) -> None:
+        """Test that ignoring users updates the latest stream ID for the ignored
+        user list account data."""
+
+        def get_latest_ignore_streampos(user_id: str) -> Optional[int]:
+            return self.get_success(
+                self.store.get_latest_stream_id_for_global_account_data_by_type_for_user(
+                    user_id, AccountDataTypes.IGNORED_USER_LIST
+                )
+            )
+
+        self.assertIsNone(get_latest_ignore_streampos("@user:test"))
+
+        self._update_ignore_list("@other:test", "@another:remote")
+
+        self.assertEqual(get_latest_ignore_streampos("@user:test"), 2)
+
+        # Add one user, remove one user, and leave one user.
+        self._update_ignore_list("@foo:test", "@another:remote")
+
+        self.assertEqual(get_latest_ignore_streampos("@user:test"), 3)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index d570684c99..7de109966d 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -43,8 +43,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         # Create a test user and room
         self.user = UserID("alice", "test")
         self.requester = create_requester(self.user)
-        info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
-        self.room_id = info["room_id"]
+        self.room_id, _, _ = self.get_success(
+            self.room_creator.create_room(self.requester, {})
+        )
 
     def run_background_update(self) -> None:
         """Re run the background update to clean up the extremities."""
@@ -275,10 +276,9 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         self.user = UserID.from_string(self.register_user("user1", "password"))
         self.token1 = self.login("user1", "password")
         self.requester = create_requester(self.user)
-        info, _ = self.get_success(
+        self.room_id, _, _ = self.get_success(
             self.room_creator.create_room(self.requester, {"visibility": "public"})
         )
-        self.room_id = info["room_id"]
         self.event_creator = homeserver.get_event_creation_handler()
         homeserver.config.consent.user_consent_version = self.CONSENT_VERSION
 
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 7f7f4ef892..cd0079871c 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -656,7 +656,6 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
 
 class ClientIpAuthTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c070278db8..e39b63edac 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
         """
 
         persist_events_store = self.hs.get_datastores().persist_events
+        assert persist_events_store is not None
 
         for e in events:
             e.internal_metadata.stream_ordering = self._next_stream_ordering
@@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
         def _persist(txn: LoggingTransaction) -> None:
             # We need to persist the events to the events and state_events
             # tables.
+            assert persist_events_store is not None
             persist_events_store._store_event_txn(
                 txn,
                 [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
@@ -415,7 +417,6 @@ class EventChainStoreTestCase(HomeserverTestCase):
     def fetch_chains(
         self, events: List[EventBase]
     ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
-
         # Fetch the map from event ID -> (chain ID, sequence number)
         rows = self.get_success(
             self.store.db_pool.simple_select_many_batch(
@@ -490,7 +491,6 @@ class LinkMapTestCase(unittest.TestCase):
 
 
 class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
-
     servlets = [
         admin.register_servlets,
         room.register_servlets,
@@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_prev_events_for_room(room_id)
         )
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_handler.create_event(
                 self.requester,
                 {
@@ -535,14 +535,17 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 prev_event_ids=latest_event_ids,
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             event_handler.handle_new_client_event(
                 self.requester, events_and_context=[(event, context)]
             )
         )
-        state1 = set(self.get_success(context.get_current_state_ids()).values())
+        state_ids1 = self.get_success(context.get_current_state_ids())
+        assert state_ids1 is not None
+        state1 = set(state_ids1.values())
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_handler.create_event(
                 self.requester,
                 {
@@ -555,12 +558,15 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 prev_event_ids=latest_event_ids,
             )
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             event_handler.handle_new_client_event(
                 self.requester, events_and_context=[(event, context)]
             )
         )
-        state2 = set(self.get_success(context.get_current_state_ids()).values())
+        state_ids2 = self.get_success(context.get_current_state_ids())
+        assert state_ids2 is not None
+        state2 = set(state_ids2.values())
 
         # Delete the chain cover info.
 
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 7fd3e01364..3e1984c15c 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
+        persist_events = hs.get_datastores().persist_events
+        assert persist_events is not None
+        self.persist_events = persist_events
 
     def test_get_prev_events_for_room(self) -> None:
         room_id = "@ROOM:local"
@@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                     },
                 )
 
-            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+            self.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
                     cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 )
 
             # Insert all events apart from 'B'
-            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+            self.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
                     cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
@@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 updatevalues={"has_auth_chain_index": False},
             )
 
-            self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+            self.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
             )
@@ -669,7 +672,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         complete_event_dict_map: Dict[str, JsonDict] = {}
         stream_ordering = 0
-        for (event_id, prev_event_ids) in event_graph.items():
+        for event_id, prev_event_ids in event_graph.items():
             depth = depth_map[event_id]
 
             complete_event_dict_map[event_id] = {
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a91411168c..6897addbd3 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -33,8 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
         events = [(3, 2), (6, 2), (4, 6)]
 
         for event_count, extrems in events:
-            info, _ = self.get_success(room_creator.create_room(requester, {}))
-            room_id = info["room_id"]
+            room_id, _, _ = self.get_success(room_creator.create_room(requester, {}))
 
             last_event = None
 
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 76c06a9d1e..aa19c3bd30 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -774,7 +774,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         self.assertEqual(r, 3)
 
         # add a bunch of dummy events to the events table
-        for (stream_ordering, ts) in (
+        for stream_ordering, ts in (
             (3, 110),
             (4, 120),
             (5, 120),
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 05661a537d..e67dd0589d 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
         self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
     ) -> None:
         self.state = self.hs.get_state_handler()
-        self._persistence = self.hs.get_storage_controllers().persistence
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self._persistence = persistence
         self._state_storage_controller = self.hs.get_storage_controllers().state
         self.store = self.hs.get_datastores().main
 
@@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
         self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
     ) -> None:
         self.state = self.hs.get_state_handler()
-        self._persistence = self.hs.get_storage_controllers().persistence
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self._persistence = persistence
         self.store = self.hs.get_datastores().main
 
     def test_remote_user_rooms_cache_invalidated(self) -> None:
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index aa4b5bd3b1..ba68171ad7 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -16,8 +16,6 @@ import signedjson.key
 import signedjson.types
 import unpaddedbase64
 
-from twisted.internet.defer import Deferred
-
 from synapse.storage.keys import FetchKeyResult
 
 import tests.unittest
@@ -44,20 +42,26 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
 
         key_id_1 = "ed25519:key1"
         key_id_2 = "ed25519:KEY_ID_2"
-        d = store.store_server_verify_keys(
-            "from_server",
-            10,
-            [
-                ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
-                ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
-            ],
+        self.get_success(
+            store.store_server_verify_keys(
+                "from_server",
+                10,
+                [
+                    ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                    ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
+                ],
+            )
         )
-        self.get_success(d)
 
-        d = store.get_server_verify_keys(
-            [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
+        res = self.get_success(
+            store.get_server_verify_keys(
+                [
+                    ("server1", key_id_1),
+                    ("server1", key_id_2),
+                    ("server1", "ed25519:key3"),
+                ]
+            )
         )
-        res = self.get_success(d)
 
         self.assertEqual(len(res.keys()), 3)
         res1 = res[("server1", key_id_1)]
@@ -82,18 +86,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         key_id_1 = "ed25519:key1"
         key_id_2 = "ed25519:key2"
 
-        d = store.store_server_verify_keys(
-            "from_server",
-            0,
-            [
-                ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
-                ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
-            ],
+        self.get_success(
+            store.store_server_verify_keys(
+                "from_server",
+                0,
+                [
+                    ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                    ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
+                ],
+            )
         )
-        self.get_success(d)
 
-        d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
-        res = self.get_success(d)
+        res = self.get_success(
+            store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+        )
         self.assertEqual(len(res.keys()), 2)
 
         res1 = res[("srv1", key_id_1)]
@@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         self.assertEqual(res2.valid_until_ts, 200)
 
         # we should be able to look up the same thing again without a db hit
-        res = store.get_server_verify_keys([("srv1", key_id_1)])
-        if isinstance(res, Deferred):
-            res = self.successResultOf(res)
+        res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
         self.assertEqual(len(res.keys()), 1)
         self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
 
@@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
-        res = self.get_success(d)
+        res = self.get_success(
+            store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
+        )
         self.assertEqual(len(res.keys()), 2)
 
         res1 = res[("srv1", key_id_1)]
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 010cc74c31..857e2caf2e 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -23,7 +23,6 @@ from tests.unittest import HomeserverTestCase
 
 
 class PurgeTests(HomeserverTestCase):
-
     user_id = "@red:server"
     servlets = [room.register_servlets]
 
@@ -112,7 +111,7 @@ class PurgeTests(HomeserverTestCase):
                 self.room_id, "m.room.create", ""
             )
         )
-        self.assertIsNotNone(create_event)
+        assert create_event is not None
 
         # Purge everything before this topological token
         self.get_success(
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index d8d84152dc..1b52eef23f 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase):
         self.store = homeserver.get_datastores().main
 
         self.room_creator = homeserver.get_room_creation_handler()
-        self.persist_event_storage_controller = (
-            self.hs.get_storage_controllers().persistence
-        )
+        persist_event_storage_controller = self.hs.get_storage_controllers().persistence
+        assert persist_event_storage_controller is not None
+        self.persist_event_storage_controller = persist_event_storage_controller
 
         # Create a test user
         self.ourUser = UserID.from_string(OUR_USER_ID)
@@ -50,12 +50,14 @@ class ReceiptTestCase(HomeserverTestCase):
         self.otherRequester = create_requester(self.otherUser)
 
         # Create a test room
-        info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
-        self.room_id1 = info["room_id"]
+        self.room_id1, _, _ = self.get_success(
+            self.room_creator.create_room(self.ourRequester, {})
+        )
 
         # Create a second test room
-        info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
-        self.room_id2 = info["room_id"]
+        self.room_id2, _, _ = self.get_success(
+            self.room_creator.create_room(self.ourRequester, {})
+        )
 
         # Join the second user to the first room
         memberEvent, memberEventContext = self.get_success(
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             def internal_metadata(self) -> _EventInternalMetadata:
                 return self._base_builder.internal_metadata
 
-        event_1, context_1 = self.get_success(
+        event_1, unpersisted_context_1 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 cast(
                     EventBuilder,
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             )
         )
 
+        context_1 = self.get_success(unpersisted_context_1.persist(event_1))
+
         self.get_success(self._persistence.persist_event(event_1, context_1))
 
-        event_2, context_2 = self.get_success(
+        event_2, unpersisted_context_2 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 cast(
                     EventBuilder,
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 )
             )
         )
+
+        context_2 = self.get_success(unpersisted_context_2.persist(event_2))
         self.get_success(self._persistence.persist_event(event_2, context_2))
 
         # fetch one of the redactions
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        redaction_event, context = self.get_success(
+        redaction_event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(redaction_event))
+
         self.get_success(self._persistence.persist_event(redaction_event, context))
 
         # Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 14d872514d..f183c38477 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase):
             "content": {"msgtype": "m.text", "body": 2},
             "room_id": room_id,
             "sender": user_id,
-            "depth": prev_event.depth + 1,
             "prev_events": prev_event_ids,
             "origin_server_ts": self.clock.time_msec(),
         }
@@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
                     prev_state_map,
                     for_verification=False,
                 ),
-                depth=event_dict["depth"],
+                depth=prev_event.depth + 1,
             )
         )
 
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 8794401823..f4c4661aaf 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -27,7 +27,6 @@ from tests.test_utils import event_injection
 
 
 class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
-
     servlets = [
         login.register_servlets,
         register_servlets_for_client_rest_resource,
@@ -35,7 +34,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: TestHomeServer) -> None:  # type: ignore[override]
-
         # We can't test the RoomMemberStore on its own without the other event
         # storage logic
         self.store = hs.get_datastores().main
@@ -48,7 +46,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         self.u_charlie = UserID.from_string("@charlie:elsewhere")
 
     def test_one_member(self) -> None:
-
         # Alice creates the room, and is automatically joined
         self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
 
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..62aed6af0a 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         assert self.storage.persistence is not None
         self.get_success(self.storage.persistence.persist_event(event, context))
 
@@ -240,7 +242,7 @@ class StateStoreTestCase(HomeserverTestCase):
 
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -257,7 +259,7 @@ class StateStoreTestCase(HomeserverTestCase):
             state_dict,
         )
 
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -270,7 +272,7 @@ class StateStoreTestCase(HomeserverTestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with wildcard types
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -287,7 +289,7 @@ class StateStoreTestCase(HomeserverTestCase):
             state_dict,
         )
 
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -307,7 +309,7 @@ class StateStoreTestCase(HomeserverTestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -325,7 +327,7 @@ class StateStoreTestCase(HomeserverTestCase):
             state_dict,
         )
 
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -339,7 +341,7 @@ class StateStoreTestCase(HomeserverTestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -390,7 +392,7 @@ class StateStoreTestCase(HomeserverTestCase):
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
         room_id = self.room.to_string()
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -402,7 +404,7 @@ class StateStoreTestCase(HomeserverTestCase):
         self.assertDictEqual({}, state_dict)
 
         room_id = self.room.to_string()
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -415,7 +417,7 @@ class StateStoreTestCase(HomeserverTestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # wildcard types
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -426,7 +428,7 @@ class StateStoreTestCase(HomeserverTestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -445,7 +447,7 @@ class StateStoreTestCase(HomeserverTestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -457,7 +459,7 @@ class StateStoreTestCase(HomeserverTestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -471,7 +473,7 @@ class StateStoreTestCase(HomeserverTestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -483,7 +485,7 @@ class StateStoreTestCase(HomeserverTestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
+        state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -494,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase):
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+    def test_batched_state_group_storing(self) -> None:
+        creation_event = self.inject_state_event(
+            self.room, self.u_alice, EventTypes.Create, "", {}
+        )
+        state_to_event = self.get_success(
+            self.storage.state.get_state_groups(
+                self.room.to_string(), [creation_event.event_id]
+            )
+        )
+        current_state_group = list(state_to_event.keys())[0]
+
+        # create some unpersisted events and event contexts to store against room
+        events_and_context = []
+        builder = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.Name,
+                "sender": self.u_alice.to_string(),
+                "state_key": "",
+                "room_id": self.room.to_string(),
+                "content": {"name": "first rename of room"},
+            },
+        )
+
+        event1, unpersisted_context1 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder)
+        )
+        events_and_context.append((event1, unpersisted_context1))
+
+        builder2 = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.JoinRules,
+                "sender": self.u_alice.to_string(),
+                "state_key": "",
+                "room_id": self.room.to_string(),
+                "content": {"join_rule": "private"},
+            },
+        )
+
+        event2, unpersisted_context2 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder2)
+        )
+        events_and_context.append((event2, unpersisted_context2))
+
+        builder3 = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.Message,
+                "sender": self.u_alice.to_string(),
+                "room_id": self.room.to_string(),
+                "content": {"body": "hello from event 3", "msgtype": "m.text"},
+            },
+        )
+
+        event3, unpersisted_context3 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder3)
+        )
+        events_and_context.append((event3, unpersisted_context3))
+
+        builder4 = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.JoinRules,
+                "sender": self.u_alice.to_string(),
+                "state_key": "",
+                "room_id": self.room.to_string(),
+                "content": {"join_rule": "public"},
+            },
+        )
+
+        event4, unpersisted_context4 = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder4)
+        )
+        events_and_context.append((event4, unpersisted_context4))
+
+        processed_events_and_context = self.get_success(
+            self.hs.get_datastores().state.store_state_deltas_for_batched(
+                events_and_context, self.room.to_string(), current_state_group
+            )
+        )
+
+        # check that only state events are in state_groups, and all state events are in state_groups
+        res = self.get_success(
+            self.store.db_pool.simple_select_list(
+                table="state_groups",
+                keyvalues=None,
+                retcols=("event_id",),
+            )
+        )
+
+        events = []
+        for result in res:
+            self.assertNotIn(event3.event_id, result)
+            events.append(result.get("event_id"))
+
+        for event, _ in processed_events_and_context:
+            if event.is_state():
+                self.assertIn(event.event_id, events)
+
+        # check that each unique state has state group in state_groups_state and that the
+        # type/state key is correct, and check that each state event's state group
+        # has an entry and prev event in state_group_edges
+        for event, context in processed_events_and_context:
+            if event.is_state():
+                state = self.get_success(
+                    self.store.db_pool.simple_select_list(
+                        table="state_groups_state",
+                        keyvalues={"state_group": context.state_group_after_event},
+                        retcols=("type", "state_key"),
+                    )
+                )
+                self.assertEqual(event.type, state[0].get("type"))
+                self.assertEqual(event.state_key, state[0].get("state_key"))
+
+                groups = self.get_success(
+                    self.store.db_pool.simple_select_list(
+                        table="state_group_edges",
+                        keyvalues={"state_group": str(context.state_group_after_event)},
+                        retcols=("*",),
+                    )
+                )
+                self.assertEqual(
+                    context.state_group_before_event, groups[0].get("prev_state_group")
+                )
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index bc090ebce0..05dc4f64b8 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -16,7 +16,7 @@ from typing import List
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
 from synapse.api.filtering import Filter
 from synapse.rest import admin
 from synapse.rest.client import login, room
@@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase):
                 room_id=self.room_id,
                 from_key=self.from_token.room_key,
                 to_key=None,
-                direction="f",
+                direction=Direction.FORWARDS,
                 limit=10,
                 event_filter=Filter(self.hs, filter),
             )
diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py
index ba53c22818..19da8a9b09 100644
--- a/tests/storage/test_unsafe_locale.py
+++ b/tests/storage/test_unsafe_locale.py
@@ -14,6 +14,7 @@
 from unittest.mock import MagicMock, patch
 
 from synapse.storage.database import make_conn
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.engines._base import IncorrectDatabaseSetup
 
 from tests.unittest import HomeserverTestCase
@@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase):
 
     def test_safe_locale(self) -> None:
         database = self.hs.get_datastores().databases[0]
+        assert isinstance(database.engine, PostgresEngine)
 
         db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
         with db_conn.cursor() as txn:
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index f1ca523d23..8c72aa1722 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -25,6 +25,11 @@ from synapse.rest.client import login, register, room
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.background_updates import _BackgroundUpdateHandler
+from synapse.storage.databases.main import user_directory
+from synapse.storage.databases.main.user_directory import (
+    _parse_words_with_icu,
+    _parse_words_with_regex,
+)
 from synapse.storage.roommember import ProfileInfo
 from synapse.util import Clock
 
@@ -42,7 +47,7 @@ ALICE = "@alice:a"
 BOB = "@bob:b"
 BOBBY = "@bobby:a"
 # The localpart isn't 'Bela' on purpose so we can test looking up display names.
-BELA = "@somenickname:a"
+BELA = "@somenickname:example.org"
 
 
 class GetUserDirectoryTables:
@@ -423,6 +428,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
 
 
 class UserDirectoryStoreTestCase(HomeserverTestCase):
+    use_icu = False
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
@@ -434,6 +441,12 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
         self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
         self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
 
+        self._restore_use_icu = user_directory.USE_ICU
+        user_directory.USE_ICU = self.use_icu
+
+    def tearDown(self) -> None:
+        user_directory.USE_ICU = self._restore_use_icu
+
     def test_search_user_dir(self) -> None:
         # normally when alice searches the directory she should just find
         # bob because bobby doesn't share a room with her.
@@ -478,6 +491,159 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
             {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
         )
 
+    @override_config({"user_directory": {"search_all_users": True}})
+    def test_search_user_dir_start_of_user_id(self) -> None:
+        """Tests that a user can look up another user by searching for the start
+        of their user ID.
+        """
+        r = self.get_success(self.store.search_user_dir(ALICE, "somenickname:exa", 10))
+        self.assertFalse(r["limited"])
+        self.assertEqual(1, len(r["results"]))
+        self.assertDictEqual(
+            r["results"][0],
+            {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+        )
+
+    @override_config({"user_directory": {"search_all_users": True}})
+    def test_search_user_dir_ascii_case_insensitivity(self) -> None:
+        """Tests that a user can look up another user by searching for their name in a
+        different case.
+        """
+        CHARLIE = "@someuser:example.org"
+        self.get_success(
+            self.store.update_profile_in_user_dir(CHARLIE, "Charlie", None)
+        )
+
+        r = self.get_success(self.store.search_user_dir(ALICE, "cHARLIE", 10))
+        self.assertFalse(r["limited"])
+        self.assertEqual(1, len(r["results"]))
+        self.assertDictEqual(
+            r["results"][0],
+            {"user_id": CHARLIE, "display_name": "Charlie", "avatar_url": None},
+        )
+
+    @override_config({"user_directory": {"search_all_users": True}})
+    def test_search_user_dir_unicode_case_insensitivity(self) -> None:
+        """Tests that a user can look up another user by searching for their name in a
+        different case.
+        """
+        IVAN = "@someuser:example.org"
+        self.get_success(self.store.update_profile_in_user_dir(IVAN, "Иван", None))
+
+        r = self.get_success(self.store.search_user_dir(ALICE, "иВАН", 10))
+        self.assertFalse(r["limited"])
+        self.assertEqual(1, len(r["results"]))
+        self.assertDictEqual(
+            r["results"][0],
+            {"user_id": IVAN, "display_name": "Иван", "avatar_url": None},
+        )
+
+    @override_config({"user_directory": {"search_all_users": True}})
+    def test_search_user_dir_dotted_dotless_i_case_insensitivity(self) -> None:
+        """Tests that a user can look up another user by searching for their name in a
+        different case, when their name contains dotted or dotless "i"s.
+
+        Some languages have dotted and dotless versions of "i", which are considered to
+        be different letters: i <-> İ, ı <-> I. To make things difficult, they reuse the
+        ASCII "i" and "I" code points, despite having different lowercase / uppercase
+        forms.
+        """
+        USER = "@someuser:example.org"
+
+        expected_matches = [
+            # (search_term, display_name)
+            # A search for "i" should match "İ".
+            ("iiiii", "İİİİİ"),
+            # A search for "I" should match "ı".
+            ("IIIII", "ııııı"),
+            # A search for "ı" should match "I".
+            ("ııııı", "IIIII"),
+            # A search for "İ" should match "i".
+            ("İİİİİ", "iiiii"),
+        ]
+
+        for search_term, display_name in expected_matches:
+            self.get_success(
+                self.store.update_profile_in_user_dir(USER, display_name, None)
+            )
+
+            r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10))
+            self.assertFalse(r["limited"])
+            self.assertEqual(
+                1,
+                len(r["results"]),
+                f"searching for {search_term!r} did not match {display_name!r}",
+            )
+            self.assertDictEqual(
+                r["results"][0],
+                {"user_id": USER, "display_name": display_name, "avatar_url": None},
+            )
+
+        # We don't test for negative matches, to allow implementations that consider all
+        # the i variants to be the same.
+
+    test_search_user_dir_dotted_dotless_i_case_insensitivity.skip = "not supported"  # type: ignore
+
+    @override_config({"user_directory": {"search_all_users": True}})
+    def test_search_user_dir_unicode_normalization(self) -> None:
+        """Tests that a user can look up another user by searching for their name with
+        either composed or decomposed accents.
+        """
+        AMELIE = "@someuser:example.org"
+
+        expected_matches = [
+            # (search_term, display_name)
+            ("Ame\u0301lie", "Amélie"),
+            ("Amélie", "Ame\u0301lie"),
+        ]
+
+        for search_term, display_name in expected_matches:
+            self.get_success(
+                self.store.update_profile_in_user_dir(AMELIE, display_name, None)
+            )
+
+            r = self.get_success(self.store.search_user_dir(ALICE, search_term, 10))
+            self.assertFalse(r["limited"])
+            self.assertEqual(
+                1,
+                len(r["results"]),
+                f"searching for {search_term!r} did not match {display_name!r}",
+            )
+            self.assertDictEqual(
+                r["results"][0],
+                {"user_id": AMELIE, "display_name": display_name, "avatar_url": None},
+            )
+
+    @override_config({"user_directory": {"search_all_users": True}})
+    def test_search_user_dir_accent_insensitivity(self) -> None:
+        """Tests that a user can look up another user by searching for their name
+        without any accents.
+        """
+        AMELIE = "@someuser:example.org"
+        self.get_success(self.store.update_profile_in_user_dir(AMELIE, "Amélie", None))
+
+        r = self.get_success(self.store.search_user_dir(ALICE, "amelie", 10))
+        self.assertFalse(r["limited"])
+        self.assertEqual(1, len(r["results"]))
+        self.assertDictEqual(
+            r["results"][0],
+            {"user_id": AMELIE, "display_name": "Amélie", "avatar_url": None},
+        )
+
+        # It may be desirable for "é"s in search terms to not match plain "e"s and we
+        # really don't want "é"s in search terms to match "e"s with different accents.
+        # But we don't test for this to allow implementations that consider all
+        # "e"-lookalikes to be the same.
+
+    test_search_user_dir_accent_insensitivity.skip = "not supported yet"  # type: ignore
+
+
+class UserDirectoryStoreTestCaseWithIcu(UserDirectoryStoreTestCase):
+    use_icu = True
+
+    if not icu:
+        skip = "Requires PyICU"
+
 
 class UserDirectoryICUTestCase(HomeserverTestCase):
     if not icu:
@@ -513,3 +679,33 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
             r["results"][0],
             {"user_id": ALICE, "display_name": display_name, "avatar_url": None},
         )
+
+    def test_icu_word_boundary_punctuation(self) -> None:
+        """
+        Tests the behaviour of punctuation with the ICU tokeniser.
+
+        Seems to depend on underlying version of ICU.
+        """
+
+        # Note: either tokenisation is fine, because Postgres actually splits
+        # words itself afterwards.
+        self.assertIn(
+            _parse_words_with_icu("lazy'fox jumped:over the.dog"),
+            (
+                # ICU 66 on Ubuntu 20.04
+                ["lazy'fox", "jumped", "over", "the", "dog"],
+                # ICU 70 on Ubuntu 22.04
+                ["lazy'fox", "jumped:over", "the.dog"],
+                # pyicu 2.10.2 on Alpine edge / macOS
+                ["lazy'fox", "jumped", "over", "the.dog"],
+            ),
+        )
+
+    def test_regex_word_boundary_punctuation(self) -> None:
+        """
+        Tests the behaviour of punctuation with the non-ICU tokeniser
+        """
+        self.assertEqual(
+            _parse_words_with_regex("lazy'fox jumped:over the.dog"),
+            ["lazy", "fox", "jumped", "over", "the", "dog"],
+        )
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 31546ea52b..a248f1d277 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -21,10 +21,10 @@ from . import unittest
 
 
 class DistributorTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.dist = Distributor()
 
-    def test_signal_dispatch(self):
+    def test_signal_dispatch(self) -> None:
         self.dist.declare("alert")
 
         observer = Mock()
@@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase):
         self.dist.fire("alert", 1, 2, 3)
         observer.assert_called_with(1, 2, 3)
 
-    def test_signal_catch(self):
+    def test_signal_catch(self) -> None:
         self.dist.declare("alarm")
 
         observers = [Mock() for i in (1, 2)]
@@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase):
             self.assertEqual(mock_logger.warning.call_count, 1)
             self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
 
-    def test_signal_prereg(self):
+    def test_signal_prereg(self) -> None:
         observer = Mock()
         self.dist.observe("flare", observer)
 
@@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase):
 
         observer.assert_called_with(4, 5)
 
-    def test_signal_undeclared(self):
-        def code():
+    def test_signal_undeclared(self) -> None:
+        def code() -> None:
             self.dist.fire("notification")
 
         self.assertRaises(KeyError, code)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 0a7937f1cc..2860564afc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result
 class _StubEventSourceStore:
     """A stub implementation of the EventSourceStore"""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._store: Dict[str, EventBase] = {}
 
-    def add_event(self, event: EventBase):
+    def add_event(self, event: EventBase) -> None:
         self._store[event.event_id] = event
 
-    def add_events(self, events: Iterable[EventBase]):
+    def add_events(self, events: Iterable[EventBase]) -> None:
         for event in events:
             self._store[event.event_id] = event
 
@@ -59,7 +59,7 @@ class _StubEventSourceStore:
 
 
 class EventAuthTestCase(unittest.TestCase):
-    def test_rejected_auth_events(self):
+    def test_rejected_auth_events(self) -> None:
         """
         Events that refer to rejected events in their auth events are rejected
         """
@@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase):
                 )
             )
 
-    def test_create_event_with_prev_events(self):
+    def test_create_event_with_prev_events(self) -> None:
         """A create event with prev_events should be rejected
 
         https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase):
                 event_auth.check_state_independent_auth_rules(event_store, bad_event)
             )
 
-    def test_duplicate_auth_events(self):
+    def test_duplicate_auth_events(self) -> None:
         """Events with duplicate auth_events should be rejected
 
         https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase):
                 event_auth.check_state_independent_auth_rules(event_store, bad_event2)
             )
 
-    def test_unexpected_auth_events(self):
+    def test_unexpected_auth_events(self) -> None:
         """Events with excess auth_events should be rejected
 
         https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase):
                 event_auth.check_state_independent_auth_rules(event_store, bad_event)
             )
 
-    def test_random_users_cannot_send_state_before_first_pl(self):
+    def test_random_users_cannot_send_state_before_first_pl(self) -> None:
         """
         Check that, before the first PL lands, the creator is the only user
         that can send a state event.
@@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events,
         )
 
-    def test_state_default_level(self):
+    def test_state_default_level(self) -> None:
         """
         Check that users above the state_default level can send state and
         those below cannot
@@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events,
         )
 
-    def test_alias_event(self):
+    def test_alias_event(self) -> None:
         """Alias events have special behavior up through room version 6."""
         creator = "@creator:example.com"
         other = "@other:example.com"
@@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events,
         )
 
-    def test_msc2432_alias_event(self):
+    def test_msc2432_alias_event(self) -> None:
         """After MSC2432, alias events have no special behavior."""
         creator = "@creator:example.com"
         other = "@other:example.com"
@@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase):
             )
 
     @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
-    def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
+    def test_notifications(
+        self, room_version: RoomVersion, allow_modification: bool
+    ) -> None:
         """
         Notifications power levels get checked due to MSC2209.
         """
@@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
             with self.assertRaises(AuthError):
                 event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
 
-    def test_join_rules_public(self):
+    def test_join_rules_public(self) -> None:
         """
         Test joining a public room.
         """
@@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events.values(),
         )
 
-    def test_join_rules_invite(self):
+    def test_join_rules_invite(self) -> None:
         """
         Test joining an invite only room.
         """
@@ -835,7 +837,7 @@ def _power_levels_event(
     )
 
 
-def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase:
     data = {
         "room_id": TEST_ROOM_ID,
         **_maybe_get_event_id_dict_for_room_version(room_version),
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 80e5c590d8..46d2f99eac 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -12,53 +12,48 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Collection, List, Optional, Union
 from unittest.mock import Mock
 
-from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import FederationError
-from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import event_from_pdu_json
+from synapse.handlers.device import DeviceListUpdater
+from synapse.http.types import QueryParams
 from synapse.logging.context import LoggingContext
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
 from tests import unittest
-from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
 from tests.test_utils import make_awaitable
 
 
 class MessageAcceptTests(unittest.HomeserverTestCase):
-    def setUp(self):
-
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.http_client = Mock()
-        self.reactor = ThreadedMemoryReactorClock()
-        self.hs_clock = Clock(self.reactor)
-        self.homeserver = setup_test_homeserver(
-            self.addCleanup,
-            federation_http_client=self.http_client,
-            clock=self.hs_clock,
-            reactor=self.reactor,
-        )
+        return self.setup_test_homeserver(federation_http_client=self.http_client)
 
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         user_id = UserID("us", "test")
         our_user = create_requester(user_id)
-        room_creator = self.homeserver.get_room_creation_handler()
+        room_creator = self.hs.get_room_creation_handler()
         self.room_id = self.get_success(
             room_creator.create_room(
                 our_user, room_creator._presets_dict["public_chat"], ratelimit=False
             )
-        )[0]["room_id"]
+        )[0]
 
-        self.store = self.homeserver.get_datastores().main
+        self.store = self.hs.get_datastores().main
 
         # Figure out what the most recent event is
         most_recent = self.get_success(
-            self.homeserver.get_datastores().main.get_latest_event_ids_in_room(
-                self.room_id
-            )
+            self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )[0]
 
         join_event = make_event_from_dict(
@@ -78,17 +73,23 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )
 
-        self.handler = self.homeserver.get_federation_handler()
-        federation_event_handler = self.homeserver.get_federation_event_handler()
+        self.handler = self.hs.get_federation_handler()
+        federation_event_handler = self.hs.get_federation_event_handler()
 
-        async def _check_event_auth(origin, event, context):
+        async def _check_event_auth(
+            origin: Optional[str], event: EventBase, context: EventContext
+        ) -> None:
             pass
 
-        federation_event_handler._check_event_auth = _check_event_auth
-        self.client = self.homeserver.get_federation_client()
-        self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
-            lambda dest, pdus, **k: succeed(pdus)
-        )
+        federation_event_handler._check_event_auth = _check_event_auth  # type: ignore[assignment]
+        self.client = self.hs.get_federation_client()
+
+        async def _check_sigs_and_hash_for_pulled_events_and_fetch(
+            dest: str, pdus: Collection[EventBase], room_version: RoomVersion
+        ) -> List[EventBase]:
+            return list(pdus)
+
+        self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch  # type: ignore[assignment]
 
         # Send the join, it should return None (which is not an error)
         self.assertEqual(
@@ -104,16 +105,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             "$join:test.serv",
         )
 
-    def test_cant_hide_direct_ancestors(self):
+    def test_cant_hide_direct_ancestors(self) -> None:
         """
         If you send a message, you must be able to provide the direct
         prev_events that said event references.
         """
 
-        async def post_json(destination, path, data, headers=None, timeout=0):
+        async def post_json(
+            destination: str,
+            path: str,
+            data: Optional[JsonDict] = None,
+            long_retries: bool = False,
+            timeout: Optional[int] = None,
+            ignore_backoff: bool = False,
+            args: Optional[QueryParams] = None,
+        ) -> Union[JsonDict, list]:
             # If it asks us for new missing events, give them NOTHING
             if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                 return {"events": []}
+            return {}
 
         self.http_client.post_json = post_json
 
@@ -138,7 +148,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )
 
-        federation_event_handler = self.homeserver.get_federation_event_handler()
+        federation_event_handler = self.hs.get_federation_event_handler()
         with LoggingContext("test-context"):
             failure = self.get_failure(
                 federation_event_handler.on_receive_pdu("test.serv", lying_event),
@@ -158,7 +168,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
         self.assertEqual(extrem[0], "$join:test.serv")
 
-    def test_retry_device_list_resync(self):
+    def test_retry_device_list_resync(self) -> None:
         """Tests that device lists are marked as stale if they couldn't be synced, and
         that stale device lists are retried periodically.
         """
@@ -171,24 +181,27 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # When this function is called, increment the number of resync attempts (only if
         # we're querying devices for the right user ID), then raise a
         # NotRetryingDestination error to fail the resync gracefully.
-        def query_user_devices(destination, user_id):
+        def query_user_devices(
+            destination: str, user_id: str, timeout: int = 30000
+        ) -> JsonDict:
             if user_id == remote_user_id:
                 self.resync_attempts += 1
 
             raise NotRetryingDestination(0, 0, destination)
 
         # Register the mock on the federation client.
-        federation_client = self.homeserver.get_federation_client()
-        federation_client.query_user_devices = Mock(side_effect=query_user_devices)
+        federation_client = self.hs.get_federation_client()
+        federation_client.query_user_devices = Mock(side_effect=query_user_devices)  # type: ignore[assignment]
 
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
-        store = self.homeserver.get_datastores().main
+        store = self.hs.get_datastores().main
         store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
 
         # Manually inject a fake device list update. We need this update to include at
         # least one prev_id so that the user's device list will need to be retried.
-        device_list_updater = self.homeserver.get_device_handler().device_list_updater
+        device_list_updater = self.hs.get_device_handler().device_list_updater
+        assert isinstance(device_list_updater, DeviceListUpdater)
         self.get_success(
             device_list_updater.incoming_device_list_update(
                 origin=remote_origin,
@@ -218,7 +231,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.reactor.advance(30)
         self.assertEqual(self.resync_attempts, 2)
 
-    def test_cross_signing_keys_retry(self):
+    def test_cross_signing_keys_retry(self) -> None:
         """Tests that resyncing a device list correctly processes cross-signing keys from
         the remote server.
         """
@@ -227,8 +240,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
         # Register mock device list retrieval on the federation client.
-        federation_client = self.homeserver.get_federation_client()
-        federation_client.query_user_devices = Mock(
+        federation_client = self.hs.get_federation_client()
+        federation_client.query_user_devices = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(
                 {
                     "user_id": remote_user_id,
@@ -252,7 +265,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         # Resync the device list.
-        device_handler = self.homeserver.get_device_handler()
+        device_handler = self.hs.get_device_handler()
         self.get_success(
             device_handler.device_list_updater.user_device_resync(remote_user_id),
         )
@@ -261,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         keys = self.get_success(
             self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
         )
-        self.assertTrue(remote_user_id in keys)
+        self.assertIn(remote_user_id, keys)
+        key = keys[remote_user_id]
+        assert key is not None
 
         # Check that the master key is the one returned by the mock.
-        master_key = keys[remote_user_id]["master"]
+        master_key = key["master"]
         self.assertEqual(len(master_key["keys"]), 1)
         self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
         self.assertTrue(remote_master_key in master_key["keys"].values())
 
         # Check that the self-signing key is the one returned by the mock.
-        self_signing_key = keys[remote_user_id]["self_signing"]
+        self_signing_key = key["self_signing"]
         self.assertEqual(len(self_signing_key["keys"]), 1)
         self.assertTrue(
             "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
@@ -279,7 +294,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
 
 class StripUnsignedFromEventsTestCase(unittest.TestCase):
-    def test_strip_unauthorized_unsigned_values(self):
+    def test_strip_unauthorized_unsigned_values(self) -> None:
         event1 = {
             "sender": "@baduser:test.serv",
             "state_key": "@baduser:test.serv",
@@ -296,7 +311,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
         # Make sure unauthorized fields are stripped from unsigned
         self.assertNotIn("more warez", filtered_event.unsigned)
 
-    def test_strip_event_maintains_allowed_fields(self):
+    def test_strip_event_maintains_allowed_fields(self) -> None:
         event2 = {
             "sender": "@baduser:test.serv",
             "state_key": "@baduser:test.serv",
@@ -323,7 +338,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
         self.assertIn("invite_room_state", filtered_event2.unsigned)
         self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
 
-    def test_strip_event_removes_fields_based_on_event_type(self):
+    def test_strip_event_removes_fields_based_on_event_type(self) -> None:
         event3 = {
             "sender": "@baduser:test.serv",
             "state_key": "@baduser:test.serv",
diff --git a/tests/test_mau.py b/tests/test_mau.py
index f14fcb7db9..ff21098a59 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -14,12 +14,17 @@
 
 """Tests REST events for /rooms paths."""
 
-from typing import List
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.rest.client import register, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.unittest import override_config
@@ -27,10 +32,9 @@ from tests.utils import default_config
 
 
 class TestMauLimit(unittest.HomeserverTestCase):
-
     servlets = [register.register_servlets, sync.register_servlets]
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = default_config("test")
 
         config.update(
@@ -53,10 +57,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.store = homeserver.get_datastores().main
 
-    def test_simple_deny_mau(self):
+    def test_simple_deny_mau(self) -> None:
         # Create and sync so that the MAU counts get updated
         token1 = self.create_user("kermit1")
         self.do_sync_for_user(token1)
@@ -75,7 +81,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.code, 403)
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
-    def test_as_ignores_mau(self):
+    def test_as_ignores_mau(self) -> None:
         """Test that application services can still create users when the MAU
         limit has been reached. This only works when application service
         user ip tracking is disabled.
@@ -113,7 +119,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
         self.create_user("as_kermit4", token=as_token, appservice=True)
 
-    def test_allowed_after_a_month_mau(self):
+    def test_allowed_after_a_month_mau(self) -> None:
         # Create and sync so that the MAU counts get updated
         token1 = self.create_user("kermit1")
         self.do_sync_for_user(token1)
@@ -132,7 +138,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.do_sync_for_user(token3)
 
     @override_config({"mau_trial_days": 1})
-    def test_trial_delay(self):
+    def test_trial_delay(self) -> None:
         # We should be able to register more than the limit initially
         token1 = self.create_user("kermit1")
         self.do_sync_for_user(token1)
@@ -165,7 +171,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
     @override_config({"mau_trial_days": 1})
-    def test_trial_users_cant_come_back(self):
+    def test_trial_users_cant_come_back(self) -> None:
         self.hs.config.server.mau_trial_days = 1
 
         # We should be able to register more than the limit initially
@@ -216,7 +222,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         # max_mau_value should not matter
         {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
     )
-    def test_tracked_but_not_limited(self):
+    def test_tracked_but_not_limited(self) -> None:
         # Simply being able to create 2 users indicates that the
         # limit was not reached.
         token1 = self.create_user("kermit1")
@@ -236,10 +242,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
             "mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2},
         }
     )
-    def test_as_trial_days(self):
+    def test_as_trial_days(self) -> None:
         user_tokens: List[str] = []
 
-        def advance_time_and_sync():
+        def advance_time_and_sync() -> None:
             self.reactor.advance(24 * 60 * 61)
             for token in user_tokens:
                 self.do_sync_for_user(token)
@@ -300,7 +306,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
             },
         )
 
-    def create_user(self, localpart, token=None, appservice=False):
+    def create_user(
+        self, localpart: str, token: Optional[str] = None, appservice: bool = False
+    ) -> str:
         request_data = {
             "username": localpart,
             "password": "monkey",
@@ -326,7 +334,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
         return access_token
 
-    def do_sync_for_user(self, token):
+    def do_sync_for_user(self, token: str) -> None:
         channel = self.make_request("GET", "/sync", access_token=token)
 
         if channel.code != 200:
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index cc1a98f1c4..3f899b0d91 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -33,7 +33,7 @@ class PhoneHomeStatsTestCase(HomeserverTestCase):
         If time doesn't move, don't error out.
         """
         past_stats = [
-            (self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))
+            (int(self.hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
         ]
         stats: JsonDict = {}
         self.get_success(phone_stats_home(self.hs, stats, past_stats))
diff --git a/tests/test_rust.py b/tests/test_rust.py
index 55d8b6b28c..67443b6280 100644
--- a/tests/test_rust.py
+++ b/tests/test_rust.py
@@ -6,6 +6,6 @@ from tests import unittest
 class RustTestCase(unittest.TestCase):
     """Basic tests to ensure that we can call into Rust code."""
 
-    def test_basic(self):
+    def test_basic(self) -> None:
         result = sum_as_string(1, 2)
         self.assertEqual("3", result)
diff --git a/tests/test_state.py b/tests/test_state.py
index 504530b49a..b20a26e1ff 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -11,7 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Collection, Dict, List, Optional, cast
+from typing import (
+    Any,
+    Collection,
+    Dict,
+    Generator,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 from unittest.mock import Mock
 
 from twisted.internet import defer
@@ -19,9 +31,11 @@ from twisted.internet import defer
 from synapse.api.auth import Auth
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
+from synapse.types import MutableStateMap, StateMap
+from synapse.types.state import StateFilter
 from synapse.util import Clock
 from synapse.util.macaroons import MacaroonGenerator
 
@@ -33,14 +47,14 @@ _next_event_id = 1000
 
 
 def create_event(
-    name=None,
-    type=None,
-    state_key=None,
-    depth=2,
-    event_id=None,
-    prev_events: Optional[List[str]] = None,
-    **kwargs,
-):
+    name: Optional[str] = None,
+    type: Optional[str] = None,
+    state_key: Optional[str] = None,
+    depth: int = 2,
+    event_id: Optional[str] = None,
+    prev_events: Optional[List[Tuple[str, dict]]] = None,
+    **kwargs: Any,
+) -> EventBase:
     global _next_event_id
 
     if not event_id:
@@ -67,21 +81,21 @@ def create_event(
 
     d.update(kwargs)
 
-    event = make_event_from_dict(d)
-
-    return event
+    return make_event_from_dict(d)
 
 
 class _DummyStore:
-    def __init__(self):
-        self._event_to_state_group = {}
-        self._group_to_state = {}
+    def __init__(self) -> None:
+        self._event_to_state_group: Dict[str, int] = {}
+        self._group_to_state: Dict[int, MutableStateMap[str]] = {}
 
-        self._event_id_to_event = {}
+        self._event_id_to_event: Dict[str, EventBase] = {}
 
         self._next_group = 1
 
-    async def get_state_groups_ids(self, room_id, event_ids):
+    async def get_state_groups_ids(
+        self, room_id: str, event_ids: Collection[str]
+    ) -> Dict[int, MutableStateMap[str]]:
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
@@ -90,16 +104,25 @@ class _DummyStore:
 
         return groups
 
-    async def get_state_ids_for_group(self, state_group, state_filter=None):
+    async def get_state_ids_for_group(
+        self, state_group: int, state_filter: Optional[StateFilter] = None
+    ) -> MutableStateMap[str]:
         return self._group_to_state[state_group]
 
     async def store_state_group(
-        self, event_id, room_id, prev_group, delta_ids, current_state_ids
-    ):
+        self,
+        event_id: str,
+        room_id: str,
+        prev_group: Optional[int],
+        delta_ids: Optional[StateMap[str]],
+        current_state_ids: Optional[StateMap[str]],
+    ) -> int:
         state_group = self._next_group
         self._next_group += 1
 
         if current_state_ids is None:
+            assert prev_group is not None
+            assert delta_ids is not None
             current_state_ids = dict(self._group_to_state[prev_group])
             current_state_ids.update(delta_ids)
 
@@ -107,7 +130,9 @@ class _DummyStore:
 
         return state_group
 
-    async def get_events(self, event_ids, **kwargs):
+    async def get_events(
+        self, event_ids: Collection[str], **kwargs: Any
+    ) -> Dict[str, EventBase]:
         return {
             e_id: self._event_id_to_event[e_id]
             for e_id in event_ids
@@ -119,31 +144,36 @@ class _DummyStore:
     ) -> Dict[str, bool]:
         return {e: False for e in event_ids}
 
-    async def get_state_group_delta(self, name):
+    async def get_state_group_delta(
+        self, name: str
+    ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
         return None, None
 
-    def register_events(self, events):
+    def register_events(self, events: Iterable[EventBase]) -> None:
         for e in events:
             self._event_id_to_event[e.event_id] = e
 
-    def register_event_context(self, event, context):
+    def register_event_context(self, event: EventBase, context: EventContext) -> None:
+        assert context.state_group is not None
         self._event_to_state_group[event.event_id] = context.state_group
 
-    def register_event_id_state_group(self, event_id, state_group):
+    def register_event_id_state_group(self, event_id: str, state_group: int) -> None:
         self._event_to_state_group[event_id] = state_group
 
-    async def get_room_version_id(self, room_id):
+    async def get_room_version_id(self, room_id: str) -> str:
         return RoomVersions.V1.identifier
 
     async def get_state_group_for_events(
-        self, event_ids, await_full_state: bool = True
-    ):
+        self, event_ids: Collection[str], await_full_state: bool = True
+    ) -> Dict[str, int]:
         res = {}
         for event in event_ids:
             res[event] = self._event_to_state_group[event]
         return res
 
-    async def get_state_for_groups(self, groups):
+    async def get_state_for_groups(
+        self, groups: Collection[int]
+    ) -> Dict[int, MutableStateMap[str]]:
         res = {}
         for group in groups:
             state = self._group_to_state[group]
@@ -152,21 +182,21 @@ class _DummyStore:
 
 
 class DictObj(dict):
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
         super().__init__(kwargs)
         self.__dict__ = self
 
 
 class Graph:
-    def __init__(self, nodes, edges):
-        events = {}
-        clobbered = set(events.keys())
+    def __init__(self, nodes: Dict[str, DictObj], edges: Dict[str, List[str]]):
+        events: Dict[str, EventBase] = {}
+        clobbered: Set[str] = set()
 
         for event_id, fields in nodes.items():
             refs = edges.get(event_id)
             if refs:
                 clobbered.difference_update(refs)
-                prev_events = [(r, {}) for r in refs]
+                prev_events: List[Tuple[str, dict]] = [(r, {}) for r in refs]
             else:
                 prev_events = []
 
@@ -177,15 +207,12 @@ class Graph:
         self._leaves = clobbered
         self._events = sorted(events.values(), key=lambda e: e.depth)
 
-    def walk(self):
+    def walk(self) -> Iterator[EventBase]:
         return iter(self._events)
 
-    def get_leaves(self):
-        return (self._events[i] for i in self._leaves)
-
 
 class StateTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.dummy_store = _DummyStore()
         storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
         hs = Mock(
@@ -220,7 +247,7 @@ class StateTestCase(unittest.TestCase):
         self.event_id = 0
 
     @defer.inlineCallbacks
-    def test_branch_no_conflict(self):
+    def test_branch_no_conflict(self) -> Generator[defer.Deferred, Any, None]:
         graph = Graph(
             nodes={
                 "START": DictObj(
@@ -248,6 +275,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
+        prev_state_ids: StateMap[str]
         prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
         self.assertEqual(2, len(prev_state_ids))
 
@@ -255,7 +283,9 @@ class StateTestCase(unittest.TestCase):
         self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
 
     @defer.inlineCallbacks
-    def test_branch_basic_conflict(self):
+    def test_branch_basic_conflict(
+        self,
+    ) -> Generator["defer.Deferred[object]", Any, None]:
         graph = Graph(
             nodes={
                 "START": DictObj(
@@ -280,7 +310,7 @@ class StateTestCase(unittest.TestCase):
 
         self.dummy_store.register_events(graph.walk())
 
-        context_store = {}
+        context_store: Dict[str, EventContext] = {}
 
         for event in graph.walk():
             context = yield defer.ensureDeferred(
@@ -294,6 +324,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
+        prev_state_ids: StateMap[str]
         prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
         self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
 
@@ -301,7 +332,9 @@ class StateTestCase(unittest.TestCase):
         self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
 
     @defer.inlineCallbacks
-    def test_branch_have_banned_conflict(self):
+    def test_branch_have_banned_conflict(
+        self,
+    ) -> Generator["defer.Deferred[object]", Any, None]:
         graph = Graph(
             nodes={
                 "START": DictObj(
@@ -338,7 +371,7 @@ class StateTestCase(unittest.TestCase):
 
         self.dummy_store.register_events(graph.walk())
 
-        context_store = {}
+        context_store: Dict[str, EventContext] = {}
 
         for event in graph.walk():
             context = yield defer.ensureDeferred(
@@ -353,13 +386,16 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_e = context_store["E"]
 
+        prev_state_ids: StateMap[str]
         prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
         self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
         self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
         self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
 
     @defer.inlineCallbacks
-    def test_branch_have_perms_conflict(self):
+    def test_branch_have_perms_conflict(
+        self,
+    ) -> Generator["defer.Deferred[object]", Any, None]:
         userid1 = "@user_id:example.com"
         userid2 = "@user_id2:example.com"
 
@@ -413,7 +449,7 @@ class StateTestCase(unittest.TestCase):
 
         self.dummy_store.register_events(graph.walk())
 
-        context_store = {}
+        context_store: Dict[str, EventContext] = {}
 
         for event in graph.walk():
             context = yield defer.ensureDeferred(
@@ -428,14 +464,17 @@ class StateTestCase(unittest.TestCase):
         ctx_b = context_store["B"]
         ctx_d = context_store["D"]
 
+        prev_state_ids: StateMap[str]
         prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
         self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
 
         self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
         self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
 
-    def _add_depths(self, nodes, edges):
-        def _get_depth(ev):
+    def _add_depths(
+        self, nodes: Dict[str, DictObj], edges: Dict[str, List[str]]
+    ) -> None:
+        def _get_depth(ev: str) -> int:
             node = nodes[ev]
             if "depth" not in node:
                 prevs = edges[ev]
@@ -447,7 +486,9 @@ class StateTestCase(unittest.TestCase):
             _get_depth(n)
 
     @defer.inlineCallbacks
-    def test_annotate_with_old_message(self):
+    def test_annotate_with_old_message(
+        self,
+    ) -> Generator["defer.Deferred[object]", Any, None]:
         event = create_event(type="test_message", name="event")
 
         old_state = [
@@ -456,6 +497,7 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
+        context: EventContext
         context = yield defer.ensureDeferred(
             self.state.compute_event_context(
                 event,
@@ -466,9 +508,11 @@ class StateTestCase(unittest.TestCase):
             )
         )
 
+        prev_state_ids: StateMap[str]
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
+        current_state_ids: StateMap[str]
         current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state), current_state_ids.values()
@@ -478,7 +522,9 @@ class StateTestCase(unittest.TestCase):
         self.assertEqual(context.state_group_before_event, context.state_group)
 
     @defer.inlineCallbacks
-    def test_annotate_with_old_state(self):
+    def test_annotate_with_old_state(
+        self,
+    ) -> Generator["defer.Deferred[object]", Any, None]:
         event = create_event(type="state", state_key="", name="event")
 
         old_state = [
@@ -487,6 +533,7 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
+        context: EventContext
         context = yield defer.ensureDeferred(
             self.state.compute_event_context(
                 event,
@@ -497,9 +544,11 @@ class StateTestCase(unittest.TestCase):
             )
         )
 
+        prev_state_ids: StateMap[str]
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
+        current_state_ids: StateMap[str]
         current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state + [event]), current_state_ids.values()
@@ -511,7 +560,9 @@ class StateTestCase(unittest.TestCase):
         self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
 
     @defer.inlineCallbacks
-    def test_trivial_annotate_message(self):
+    def test_trivial_annotate_message(
+        self,
+    ) -> Generator["defer.Deferred[object]", Any, None]:
         prev_event_id = "prev_event_id"
         event = create_event(
             type="test_message", name="event2", prev_events=[(prev_event_id, {})]
@@ -534,8 +585,10 @@ class StateTestCase(unittest.TestCase):
         )
         self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
 
+        context: EventContext
         context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
+        current_state_ids: StateMap[str]
         current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(
@@ -545,7 +598,9 @@ class StateTestCase(unittest.TestCase):
         self.assertEqual(group_name, context.state_group)
 
     @defer.inlineCallbacks
-    def test_trivial_annotate_state(self):
+    def test_trivial_annotate_state(
+        self,
+    ) -> Generator["defer.Deferred[object]", Any, None]:
         prev_event_id = "prev_event_id"
         event = create_event(
             type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
@@ -568,8 +623,10 @@ class StateTestCase(unittest.TestCase):
         )
         self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
 
+        context: EventContext
         context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
+        prev_state_ids: StateMap[str]
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
 
         self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
@@ -577,7 +634,9 @@ class StateTestCase(unittest.TestCase):
         self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
-    def test_resolve_message_conflict(self):
+    def test_resolve_message_conflict(
+        self,
+    ) -> Generator["defer.Deferred[Any]", Any, None]:
         prev_event_id1 = "event_id1"
         prev_event_id2 = "event_id2"
         event = create_event(
@@ -605,10 +664,12 @@ class StateTestCase(unittest.TestCase):
         self.dummy_store.register_events(old_state_1)
         self.dummy_store.register_events(old_state_2)
 
+        context: EventContext
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
+        current_state_ids: StateMap[str]
         current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
@@ -616,7 +677,9 @@ class StateTestCase(unittest.TestCase):
         self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
-    def test_resolve_state_conflict(self):
+    def test_resolve_state_conflict(
+        self,
+    ) -> Generator["defer.Deferred[Any]", Any, None]:
         prev_event_id1 = "event_id1"
         prev_event_id2 = "event_id2"
         event = create_event(
@@ -645,12 +708,14 @@ class StateTestCase(unittest.TestCase):
         store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.dummy_store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events  # type: ignore[assignment]
 
+        context: EventContext
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
+        current_state_ids: StateMap[str]
         current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
@@ -658,7 +723,9 @@ class StateTestCase(unittest.TestCase):
         self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
-    def test_standard_depth_conflict(self):
+    def test_standard_depth_conflict(
+        self,
+    ) -> Generator["defer.Deferred[Any]", Any, None]:
         prev_event_id1 = "event_id1"
         prev_event_id2 = "event_id2"
         event = create_event(
@@ -700,12 +767,14 @@ class StateTestCase(unittest.TestCase):
         store = _DummyStore()
         store.register_events(old_state_1)
         store.register_events(old_state_2)
-        self.dummy_store.get_events = store.get_events
+        self.dummy_store.get_events = store.get_events  # type: ignore[assignment]
 
+        context: EventContext
         context = yield self._get_context(
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
+        current_state_ids: StateMap[str]
         current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -740,8 +809,14 @@ class StateTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def _get_context(
-        self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
-    ):
+        self,
+        event: EventBase,
+        prev_event_id_1: str,
+        old_state_1: Collection[EventBase],
+        prev_event_id_2: str,
+        old_state_2: Collection[EventBase],
+    ) -> Generator["defer.Deferred[object]", Any, EventContext]:
+        sg1: int
         sg1 = yield defer.ensureDeferred(
             self.dummy_store.store_state_group(
                 prev_event_id_1,
@@ -753,6 +828,7 @@ class StateTestCase(unittest.TestCase):
         )
         self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1)
 
+        sg2: int
         sg2 = yield defer.ensureDeferred(
             self.dummy_store.store_state_group(
                 prev_event_id_2,
@@ -767,7 +843,7 @@ class StateTestCase(unittest.TestCase):
         result = yield defer.ensureDeferred(self.state.compute_event_context(event))
         return result
 
-    def test_make_state_cache_entry(self):
+    def test_make_state_cache_entry(self) -> None:
         "Test that calculating a prev_group and delta is correct"
 
         new_state = {
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index abd7459a8c..52424aa087 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -14,9 +14,12 @@
 
 from unittest.mock import Mock
 
-from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.internet.interfaces import IReactorTime
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
 
 from synapse.rest.client.register import register_servlets
+from synapse.server import HomeServer
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -25,7 +28,7 @@ from tests import unittest
 class TermsTestCase(unittest.HomeserverTestCase):
     servlets = [register_servlets]
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = super().default_config()
         config.update(
             {
@@ -40,17 +43,21 @@ class TermsTestCase(unittest.HomeserverTestCase):
         )
         return config
 
-    def prepare(self, reactor, clock, hs):
-        self.clock = MemoryReactorClock()
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        # type-ignore: mypy-zope doesn't seem to recognise that MemoryReactorClock
+        # implements IReactorTime, via inheritance from twisted.internet.testing.Clock
+        self.clock: IReactorTime = MemoryReactorClock()  # type: ignore[assignment]
         self.hs_clock = Clock(self.clock)
         self.url = "/_matrix/client/r0/register"
         self.registration_handler = Mock()
         self.auth_handler = Mock()
         self.device_handler = Mock()
 
-    def test_ui_auth(self):
+    def test_ui_auth(self) -> None:
         # Do a UI auth request
-        request_data = {"username": "kermit", "password": "monkey"}
+        request_data: JsonDict = {"username": "kermit", "password": "monkey"}
         channel = self.make_request(b"POST", self.url, request_data)
 
         self.assertEqual(channel.code, 401, channel.result)
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index d04bcae0fa..5cd698147e 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -17,25 +17,25 @@ from tests.utils import MockClock
 
 
 class MockClockTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.clock = MockClock()
 
-    def test_advance_time(self):
+    def test_advance_time(self) -> None:
         start_time = self.clock.time()
 
         self.clock.advance_time(20)
 
         self.assertEqual(20, self.clock.time() - start_time)
 
-    def test_later(self):
+    def test_later(self) -> None:
         invoked = [0, 0]
 
-        def _cb0():
+        def _cb0() -> None:
             invoked[0] = 1
 
         self.clock.call_later(10, _cb0)
 
-        def _cb1():
+        def _cb1() -> None:
             invoked[1] = 1
 
         self.clock.call_later(20, _cb1)
@@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase):
 
         self.assertTrue(invoked[1])
 
-    def test_cancel_later(self):
+    def test_cancel_later(self) -> None:
         invoked = [0, 0]
 
-        def _cb0():
+        def _cb0() -> None:
             invoked[0] = 1
 
         t0 = self.clock.call_later(10, _cb0)
 
-        def _cb1():
+        def _cb1() -> None:
             invoked[1] = 1
 
         self.clock.call_later(20, _cb1)
diff --git a/tests/test_types.py b/tests/test_types.py
index 1111169384..c491cc9a96 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase):
 
 
 class UserIDTestCase(unittest.HomeserverTestCase):
-    def test_parse(self):
+    def test_parse(self) -> None:
         user = UserID.from_string("@1234abcd:test")
 
         self.assertEqual("1234abcd", user.localpart)
         self.assertEqual("test", user.domain)
         self.assertEqual(True, self.hs.is_mine(user))
 
-    def test_parse_rejects_empty_id(self):
+    def test_parse_rejects_empty_id(self) -> None:
         with self.assertRaises(SynapseError):
             UserID.from_string("")
 
-    def test_parse_rejects_missing_sigil(self):
+    def test_parse_rejects_missing_sigil(self) -> None:
         with self.assertRaises(SynapseError):
             UserID.from_string("alice:example.com")
 
-    def test_parse_rejects_missing_separator(self):
+    def test_parse_rejects_missing_separator(self) -> None:
         with self.assertRaises(SynapseError):
             UserID.from_string("@alice.example.com")
 
-    def test_validation_rejects_missing_domain(self):
+    def test_validation_rejects_missing_domain(self) -> None:
         self.assertFalse(UserID.is_valid("@alice:"))
 
-    def test_build(self):
+    def test_build(self) -> None:
         user = UserID("5678efgh", "my.domain")
 
         self.assertEqual(user.to_string(), "@5678efgh:my.domain")
 
-    def test_compare(self):
+    def test_compare(self) -> None:
         userA = UserID.from_string("@userA:my.domain")
         userAagain = UserID.from_string("@userA:my.domain")
         userB = UserID.from_string("@userB:my.domain")
@@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase):
 
 
 class RoomAliasTestCase(unittest.HomeserverTestCase):
-    def test_parse(self):
+    def test_parse(self) -> None:
         room = RoomAlias.from_string("#channel:test")
 
         self.assertEqual("channel", room.localpart)
         self.assertEqual("test", room.domain)
         self.assertEqual(True, self.hs.is_mine(room))
 
-    def test_build(self):
+    def test_build(self) -> None:
         room = RoomAlias("channel", "my.domain")
 
         self.assertEqual(room.to_string(), "#channel:my.domain")
 
-    def test_validate(self):
+    def test_validate(self) -> None:
         id_string = "#test:domain,test"
         self.assertFalse(RoomAlias.is_valid(id_string))
 
 
 class MapUsernameTestCase(unittest.TestCase):
-    def testPassThrough(self):
+    def test_pass_througuh(self) -> None:
         self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
 
-    def testUpperCase(self):
+    def test_upper_case(self) -> None:
         self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
         self.assertEqual(
             map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
             "t_e_s_t__1234",
         )
 
-    def testSymbols(self):
+    def test_symbols(self) -> None:
         self.assertEqual(
             map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
         )
 
-    def testLeadingUnderscore(self):
+    def test_leading_underscore(self) -> None:
         self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
 
-    def testNonAscii(self):
+    def test_non_ascii(self) -> None:
         # this should work with either a unicode or a bytes
         self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
         self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e62ebcc6a5..e5dae670a7 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -20,12 +20,13 @@ import sys
 import warnings
 from asyncio import Future
 from binascii import unhexlify
-from typing import Awaitable, Callable, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
 from unittest.mock import Mock
 
 import attr
 import zope.interface
 
+from twisted.internet.interfaces import IProtocol
 from twisted.python.failure import Failure
 from twisted.web.client import ResponseDone
 from twisted.web.http import RESPONSES
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
 
 from synapse.types import JsonDict
 
+if TYPE_CHECKING:
+    from sys import UnraisableHookArgs
+
 TV = TypeVar("TV")
 
 
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
     unraisable_exceptions = []
     orig_unraisablehook = sys.unraisablehook
 
-    def unraisablehook(unraisable):
+    def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
         unraisable_exceptions.append(unraisable.exc_value)
 
-    def cleanup():
+    def cleanup() -> None:
         """
         A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
         """
         sys.unraisablehook = orig_unraisablehook
         if unraisable_exceptions:
-            raise unraisable_exceptions.pop()
+            exc = unraisable_exceptions.pop()
+            assert exc is not None
+            raise exc
 
     sys.unraisablehook = unraisablehook
 
     return cleanup
 
 
-def simple_async_mock(return_value=None, raises=None) -> Mock:
+def simple_async_mock(
+    return_value: Optional[TV] = None, raises: Optional[Exception] = None
+) -> Mock:
     # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args, **kwargs):
+    async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
         if raises:
             raise raises
         return return_value
@@ -125,14 +133,14 @@ class FakeResponse:  # type: ignore[misc]
     headers: Headers = attr.Factory(Headers)
 
     @property
-    def phrase(self):
+    def phrase(self) -> bytes:
         return RESPONSES.get(self.code, b"Unknown Status")
 
     @property
-    def length(self):
+    def length(self) -> int:
         return len(self.body)
 
-    def deliverBody(self, protocol):
+    def deliverBody(self, protocol: IProtocol) -> None:
         protocol.dataReceived(self.body)
         protocol.connectionLost(Failure(ResponseDone()))
 
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8027c7a856..a6330ed840 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import synapse.server
 from synapse.api.constants import EventTypes
@@ -32,7 +32,7 @@ async def inject_member_event(
     membership: str,
     target: Optional[str] = None,
     extra_content: Optional[dict] = None,
-    **kwargs,
+    **kwargs: Any,
 ) -> EventBase:
     """Inject a membership event into a room."""
     if target is None:
@@ -57,7 +57,7 @@ async def inject_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[List[str]] = None,
-    **kwargs,
+    **kwargs: Any,
 ) -> EventBase:
     """Inject a generic event into a room
 
@@ -82,7 +82,7 @@ async def create_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[List[str]] = None,
-    **kwargs,
+    **kwargs: Any,
 ) -> Tuple[EventBase, EventContext]:
     if room_version is None:
         room_version = await hs.get_datastores().main.get_room_version_id(
@@ -92,8 +92,13 @@ async def create_event(
     builder = hs.get_event_builder_factory().for_room_version(
         KNOWN_ROOM_VERSIONS[room_version], kwargs
     )
-    event, context = await hs.get_event_creation_handler().create_new_client_event(
+    (
+        event,
+        unpersisted_context,
+    ) = await hs.get_event_creation_handler().create_new_client_event(
         builder, prev_event_ids=prev_event_ids
     )
 
+    context = await unpersisted_context.persist(event)
+
     return event, context
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index e878af5f12..189c697efb 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -13,13 +13,13 @@
 # limitations under the License.
 
 from html.parser import HTMLParser
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, NoReturn, Optional, Tuple
 
 
 class TestHtmlParser(HTMLParser):
     """A generic HTML page parser which extracts useful things from the HTML"""
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
 
         # a list of links found in the doc
@@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser):
                 assert input_name
                 self.hiddens[input_name] = attr_dict["value"]
 
-    def error(_, message):
+    def error(self, message: str) -> NoReturn:
         raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 304c7b98c5..b522163a34 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler):
 
     tx_log = twisted.logger.Logger()
 
-    def emit(self, record):
+    def emit(self, record: logging.LogRecord) -> None:
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace("warning", "warn")
         self.tx_log.emit(
@@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler):
         )
 
 
-def setup_logging():
+def setup_logging() -> None:
     """Configure the python logging appropriately for the tests.
 
     (Logs will end up in _trial_temp.)
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@
 
 
 import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch
 from urllib.parse import parse_qs
 
@@ -77,14 +77,14 @@ class FakeOidcServer:
 
         self._id_token_overrides: Dict[str, Any] = {}
 
-    def reset_mocks(self):
+    def reset_mocks(self) -> None:
         self.request.reset_mock()
         self.get_jwks_handler.reset_mock()
         self.get_metadata_handler.reset_mock()
         self.get_userinfo_handler.reset_mock()
         self.post_token_handler.reset_mock()
 
-    def patch_homeserver(self, hs: HomeServer):
+    def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
         """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
 
         This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
 
         return self._sign(logout_token)
 
-    def id_token_override(self, overrides: dict):
+    def id_token_override(self, overrides: dict) -> ContextManager[dict]:
         """Temporarily patch the ID token generated by the token endpoint."""
         return patch.object(self, "_id_token_overrides", overrides)
 
@@ -247,7 +247,7 @@ class FakeOidcServer:
         metadata: bool = False,
         token: bool = False,
         userinfo: bool = False,
-    ):
+    ) -> ContextManager[Dict[str, Mock]]:
         """A context which makes a set of endpoints return a 500 error.
 
         Args:
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0b9ad5454..2801a950a8 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -35,6 +35,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         self.event_creation_handler = self.hs.get_event_creation_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
         self._storage_controllers = self.hs.get_storage_controllers()
+        assert self._storage_controllers.persistence is not None
+        self._persistence = self._storage_controllers.persistence
 
         self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
 
@@ -175,12 +177,11 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
-        self.get_success(
-            self._storage_controllers.persistence.persist_event(event, context)
-        )
+        context = self.get_success(unpersisted_context.persist(event))
+        self.get_success(self._persistence.persist_event(event, context))
         return event
 
     def _inject_room_member(
@@ -202,13 +203,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
-        self.get_success(
-            self._storage_controllers.persistence.persist_event(event, context)
-        )
+        self.get_success(self._persistence.persist_event(event, context))
         return event
 
     def _inject_message(
@@ -226,13 +226,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
-        self.get_success(
-            self._storage_controllers.persistence.persist_event(event, context)
-        )
+        self.get_success(self._persistence.persist_event(event, context))
         return event
 
     def _inject_outlier(self) -> EventBase:
@@ -250,7 +249,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
         event.internal_metadata.outlier = True
         self.get_success(
-            self._storage_controllers.persistence.persist_event(
+            self._persistence.persist_event(
                 event, EventContext.for_outlier(self._storage_controllers)
             )
         )
@@ -258,7 +257,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
 
 class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
-    def test_out_of_band_invite_rejection(self):
+    def test_out_of_band_invite_rejection(self) -> None:
         # this is where we have received an invite event over federation, and then
         # rejected it.
         invite_pdu = {
diff --git a/tests/unittest.py b/tests/unittest.py
index 50aa5abda9..6625fe1688 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
 from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
-from twisted.test.proto_helpers import MemoryReactor
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
 from twisted.trial import unittest
 from twisted.web.resource import Resource
 from twisted.web.server import Request
@@ -82,7 +82,7 @@ from tests.server import (
 )
 from tests.test_utils import event_injection, setup_awaitable_errors
 from tests.test_utils.logging_setup import setup_logging
-from tests.utils import default_config, setupdb
+from tests.utils import checked_cast, default_config, setupdb
 
 setupdb()
 setup_logging()
@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
 
         from tests.rest.client.utils import RestHelper
 
-        self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
+        self.helper = RestHelper(
+            self.hs,
+            checked_cast(MemoryReactorClock, self.hs.get_reactor()),
+            self.site,
+            getattr(self, "user_id", None),
+        )
 
         if hasattr(self, "user_id"):
             if self.hijack_auth:
@@ -315,7 +320,7 @@ class HomeserverTestCase(TestCase):
 
                 # This has to be a function and not just a Mock, because
                 # `self.helper.auth_user_id` is temporarily reassigned in some tests
-                async def get_requester(*args, **kwargs) -> Requester:
+                async def get_requester(*args: Any, **kwargs: Any) -> Requester:
                     assert self.helper.auth_user_id is not None
                     return create_requester(
                         user_id=UserID.from_string(self.helper.auth_user_id),
@@ -361,7 +366,9 @@ class HomeserverTestCase(TestCase):
                 store.db_pool.updates.do_next_background_update(False), by=0.1
             )
 
-    def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock):
+    def make_homeserver(
+        self, reactor: ThreadedMemoryReactorClock, clock: Clock
+    ) -> HomeServer:
         """
         Make and return a homeserver.
 
@@ -716,7 +723,7 @@ class HomeserverTestCase(TestCase):
         event_creator = self.hs.get_event_creation_handler()
         requester = create_requester(user)
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_creator.create_event(
                 requester,
                 {
@@ -728,7 +735,7 @@ class HomeserverTestCase(TestCase):
                 prev_event_ids=prev_event_ids,
             )
         )
-
+        context = self.get_success(unpersisted_context.persist(event))
         if soft_failed:
             event.internal_metadata.soft_failed = True
 
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9529ee53c8..5f8f4e76b5 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -54,6 +54,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
         self.pump()
 
         new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+        assert new_timings is not None
         self.assertEqual(new_timings.failure_ts, failure_ts)
         self.assertEqual(new_timings.retry_last_ts, failure_ts)
         self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
@@ -82,6 +83,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
         self.pump()
 
         new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
+        assert new_timings is not None
         self.assertEqual(new_timings.failure_ts, failure_ts)
         self.assertEqual(new_timings.retry_last_ts, retry_ts)
         self.assertGreaterEqual(
diff --git a/tests/utils.py b/tests/utils.py
index d76bf9716a..a0ac11bc5c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,7 +15,7 @@
 
 import atexit
 import os
-from typing import Any, Callable, Dict, List, Tuple, Union, overload
+from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, overload
 
 import attr
 from typing_extensions import Literal, ParamSpec
@@ -335,6 +335,33 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
         },
     )
 
-    event, context = await event_creation_handler.create_new_client_event(builder)
+    event, unpersisted_context = await event_creation_handler.create_new_client_event(
+        builder
+    )
+    context = await unpersisted_context.persist(event)
 
     await persistence_store.persist_event(event, context)
+
+
+T = TypeVar("T")
+
+
+def checked_cast(type: Type[T], x: object) -> T:
+    """A version of typing.cast that is checked at runtime.
+
+    We have our own function for this for two reasons:
+
+    1. typing.cast itself is deliberately a no-op at runtime, see
+       https://docs.python.org/3/library/typing.html#typing.cast
+    2. To help workaround a mypy-zope bug https://github.com/Shoobx/mypy-zope/issues/91
+       where mypy would erroneously consider `isinstance(x, type)` to be false in all
+       circumstances.
+
+    For this to make sense, `T` needs to be something that `isinstance` can check; see
+        https://docs.python.org/3/library/functions.html?highlight=isinstance#isinstance
+        https://docs.python.org/3/glossary.html#term-abstract-base-class
+        https://docs.python.org/3/library/typing.html#typing.runtime_checkable
+    for more details.
+    """
+    assert isinstance(x, type)
+    return x