summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/events/test_presence_router.py10
-rw-r--r--tests/handlers/test_oidc.py4
-rw-r--r--tests/handlers/test_user_directory.py4
-rw-r--r--tests/module_api/test_api.py122
-rw-r--r--tests/push/test_push_rule_evaluator.py8
-rw-r--r--tests/rest/admin/test_media.py9
-rw-r--r--tests/rest/admin/test_user.py4
-rw-r--r--tests/rest/media/v1/test_media_storage.py49
-rw-r--r--tests/scripts/test_new_matrix_user.py25
-rw-r--r--tests/server_notices/test_consent.py14
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py35
-rw-r--r--tests/storage/test_redaction.py24
-rw-r--r--tests/storage/test_state.py4
-rw-r--r--tests/test_distributor.py12
-rw-r--r--tests/test_event_auth.py32
-rw-r--r--tests/test_federation.py80
-rw-r--r--tests/test_mau.py35
-rw-r--r--tests/test_rust.py2
-rw-r--r--tests/test_test_utils.py16
-rw-r--r--tests/test_types.py30
-rw-r--r--tests/test_utils/__init__.py26
-rw-r--r--tests/test_utils/event_injection.py15
-rw-r--r--tests/test_utils/html_parsers.py6
-rw-r--r--tests/test_utils/logging_setup.py4
-rw-r--r--tests/test_utils/oidc.py10
-rw-r--r--tests/test_visibility.py11
-rw-r--r--tests/unittest.py2
-rw-r--r--tests/utils.py5
28 files changed, 359 insertions, 239 deletions
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index a9893def74..741bb6464a 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -31,7 +31,11 @@ from synapse.util import Clock
 
 from tests.handlers.test_sync import generate_sync_config
 from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import (
+    FederatingHomeserverTestCase,
+    HomeserverTestCase,
+    override_config,
+)
 
 
 @attr.s
@@ -470,7 +474,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
 
 
 def send_presence_update(
-    testcase: FederatingHomeserverTestCase,
+    testcase: HomeserverTestCase,
     user_id: str,
     access_token: str,
     presence_state: str,
@@ -491,7 +495,7 @@ def send_presence_update(
 
 
 def sync_presence(
-    testcase: FederatingHomeserverTestCase,
+    testcase: HomeserverTestCase,
     user_id: str,
     since_token: Optional[StreamToken] = None,
 ) -> Tuple[List[UserPresenceState], StreamToken]:
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index adddbd002f..951caaa6b3 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         hs = self.setup_test_homeserver()
         self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
-        self.hs_patcher.start()
+        self.hs_patcher.start()  # type: ignore[attr-defined]
 
         self.handler = hs.get_oidc_handler()
         self.provider = self.handler._providers["oidc"]
@@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def tearDown(self) -> None:
-        self.hs_patcher.stop()
+        self.hs_patcher.stop()  # type: ignore[attr-defined]
         return super().tearDown()
 
     def reset_mocks(self) -> None:
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..e9be5fb504 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -949,10 +949,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(
             self.hs.get_storage_controllers().persistence.persist_event(event, context)
         )
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 8f88c0117d..cc173ebda6 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict
 from unittest.mock import Mock
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EduTypes, EventTypes
 from synapse.api.errors import NotFoundError
@@ -21,9 +23,12 @@ from synapse.events import EventBase
 from synapse.federation.units import Transaction
 from synapse.handlers.presence import UserPresenceState
 from synapse.handlers.push_rules import InvalidRuleException
+from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login, notifications, presence, profile, room
-from synapse.types import create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
+from synapse.util import Clock
 
 from tests.events.test_presence_router import send_presence_update, sync_presence
 from tests.replication._base import BaseMultiWorkerStreamTestCase
@@ -32,7 +37,19 @@ from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import HomeserverTestCase, override_config
 
 
-class ModuleApiTestCase(HomeserverTestCase):
+class BaseModuleApiTestCase(HomeserverTestCase):
+    """Common properties of the two test case classes."""
+
+    module_api: ModuleApi
+
+    # These are all written by _test_sending_local_online_presence_to_local_user.
+    presence_receiver_id: str
+    presence_receiver_tok: str
+    presence_sender_id: str
+    presence_sender_tok: str
+
+
+class ModuleApiTestCase(BaseModuleApiTestCase):
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -42,14 +59,14 @@ class ModuleApiTestCase(HomeserverTestCase):
         notifications.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, homeserver):
-        self.store = homeserver.get_datastores().main
-        self.module_api = homeserver.get_module_api()
-        self.event_creation_handler = homeserver.get_event_creation_handler()
-        self.sync_handler = homeserver.get_sync_handler()
-        self.auth_handler = homeserver.get_auth_handler()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.module_api = hs.get_module_api()
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.sync_handler = hs.get_sync_handler()
+        self.auth_handler = hs.get_auth_handler()
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         fed_transport_client = Mock(spec=["send_transaction"])
         fed_transport_client.send_transaction = simple_async_mock({})
@@ -58,7 +75,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             federation_transport_client=fed_transport_client,
         )
 
-    def test_can_register_user(self):
+    def test_can_register_user(self) -> None:
         """Tests that an external module can register a user"""
         # Register a new user
         user_id, access_token = self.get_success(
@@ -88,16 +105,17 @@ class ModuleApiTestCase(HomeserverTestCase):
         displayname = self.get_success(self.store.get_profile_displayname("bob"))
         self.assertEqual(displayname, "Bobberino")
 
-    def test_can_register_admin_user(self):
+    def test_can_register_admin_user(self) -> None:
         user_id = self.register_user(
             "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
         )
 
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        assert found_user is not None
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, True)
 
-    def test_can_set_admin(self):
+    def test_can_set_admin(self) -> None:
         user_id = self.register_user(
             "alice_wants_admin",
             "1234",
@@ -107,16 +125,17 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         self.get_success(self.module_api.set_user_admin(user_id, True))
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        assert found_user is not None
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, True)
 
-    def test_can_set_displayname(self):
+    def test_can_set_displayname(self) -> None:
         localpart = "alice_wants_a_new_displayname"
         user_id = self.register_user(
             localpart, "1234", displayname="Alice", admin=False
         )
         found_userinfo = self.get_success(self.module_api.get_userinfo_by_id(user_id))
-
+        assert found_userinfo is not None
         self.get_success(
             self.module_api.set_displayname(
                 found_userinfo.user_id, "Bob", deactivation=False
@@ -128,17 +147,18 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         self.assertEqual(found_profile.display_name, "Bob")
 
-    def test_get_userinfo_by_id(self):
+    def test_get_userinfo_by_id(self) -> None:
         user_id = self.register_user("alice", "1234")
         found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        assert found_user is not None
         self.assertEqual(found_user.user_id.to_string(), user_id)
         self.assertIdentical(found_user.is_admin, False)
 
-    def test_get_userinfo_by_id__no_user_found(self):
+    def test_get_userinfo_by_id__no_user_found(self) -> None:
         found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
         self.assertIsNone(found_user)
 
-    def test_get_user_ip_and_agents(self):
+    def test_get_user_ip_and_agents(self) -> None:
         user_id = self.register_user("test_get_user_ip_and_agents_user", "1234")
 
         # Initially, we should have no ip/agent for our user.
@@ -185,7 +205,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         # we should only find the second ip, agent.
         info = self.get_success(
             self.module_api.get_user_ip_and_agents(
-                user_id, (last_seen_1 + last_seen_2) / 2
+                user_id, (last_seen_1 + last_seen_2) // 2
             )
         )
         self.assertEqual(len(info), 1)
@@ -200,7 +220,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         self.assertEqual(info, [])
 
-    def test_get_user_ip_and_agents__no_user_found(self):
+    def test_get_user_ip_and_agents__no_user_found(self) -> None:
         info = self.get_success(
             self.module_api.get_user_ip_and_agents(
                 "@test_get_user_ip_and_agents_user_nonexistent:example.com"
@@ -208,10 +228,10 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         self.assertEqual(info, [])
 
-    def test_sending_events_into_room(self):
+    def test_sending_events_into_room(self) -> None:
         """Tests that a module can send events into a room"""
         # Mock out create_and_send_nonmember_event to check whether events are being sent
-        self.event_creation_handler.create_and_send_nonmember_event = Mock(
+        self.event_creation_handler.create_and_send_nonmember_event = Mock(  # type: ignore[assignment]
             spec=[],
             side_effect=self.event_creation_handler.create_and_send_nonmember_event,
         )
@@ -222,7 +242,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         room_id = self.helper.create_room_as(user_id, tok=tok)
 
         # Create and send a non-state event
-        content = {"body": "I am a puppet", "msgtype": "m.text"}
+        content: JsonDict = {"body": "I am a puppet", "msgtype": "m.text"}
         event_dict = {
             "room_id": room_id,
             "type": "m.room.message",
@@ -265,7 +285,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             "sender": user_id,
             "state_key": "",
         }
-        event: EventBase = self.get_success(
+        event = self.get_success(
             self.module_api.create_and_send_event_into_room(event_dict)
         )
         self.assertEqual(event.sender, user_id)
@@ -303,7 +323,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             self.module_api.create_and_send_event_into_room(event_dict), Exception
         )
 
-    def test_public_rooms(self):
+    def test_public_rooms(self) -> None:
         """Tests that a room can be added and removed from the public rooms list,
         as well as have its public rooms directory state queried.
         """
@@ -350,13 +370,13 @@ class ModuleApiTestCase(HomeserverTestCase):
         )
         self.assertFalse(is_in_public_rooms)
 
-    def test_send_local_online_presence_to(self):
+    def test_send_local_online_presence_to(self) -> None:
         # Test sending local online presence to users from the main process
         _test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
 
     # Enable federation sending on the main process.
     @override_config({"federation_sender_instances": None})
-    def test_send_local_online_presence_to_federation(self):
+    def test_send_local_online_presence_to_federation(self) -> None:
         """Tests that send_local_presence_to_users sends local online presence to remote users."""
         # Create a user who will send presence updates
         self.presence_sender_id = self.register_user("presence_sender1", "monkey")
@@ -431,7 +451,7 @@ class ModuleApiTestCase(HomeserverTestCase):
 
         self.assertTrue(found_update)
 
-    def test_update_membership(self):
+    def test_update_membership(self) -> None:
         """Tests that the module API can update the membership of a user in a room."""
         peter = self.register_user("peter", "hackme")
         lesley = self.register_user("lesley", "hackme")
@@ -554,7 +574,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.assertEqual(res["displayname"], "simone")
         self.assertIsNone(res["avatar_url"])
 
-    def test_update_room_membership_remote_join(self):
+    def test_update_room_membership_remote_join(self) -> None:
         """Test that the module API can join a remote room."""
         # Necessary to fake a remote join.
         fake_stream_id = 1
@@ -582,7 +602,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         # Check that a remote join was attempted.
         self.assertEqual(mocked_remote_join.call_count, 1)
 
-    def test_get_room_state(self):
+    def test_get_room_state(self) -> None:
         """Tests that a module can retrieve the state of a room through the module API."""
         user_id = self.register_user("peter", "hackme")
         tok = self.login("peter", "hackme")
@@ -677,7 +697,7 @@ class ModuleApiTestCase(HomeserverTestCase):
             self.module_api.check_push_rule_actions(["foo"])
 
         with self.assertRaises(InvalidRuleException):
-            self.module_api.check_push_rule_actions({"foo": "bar"})
+            self.module_api.check_push_rule_actions([{"foo": "bar"}])
 
         self.module_api.check_push_rule_actions(["notify"])
 
@@ -756,7 +776,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         self.assertIsNone(room_alias)
 
 
-class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
+class ModuleApiWorkerTestCase(BaseModuleApiTestCase, BaseMultiWorkerStreamTestCase):
     """For testing ModuleApi functionality in a multi-worker setup"""
 
     servlets = [
@@ -766,7 +786,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
         presence.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         conf = super().default_config()
         conf["stream_writers"] = {"presence": ["presence_writer"]}
         conf["instance_map"] = {
@@ -774,18 +794,18 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
         }
         return conf
 
-    def prepare(self, reactor, clock, homeserver):
-        self.module_api = homeserver.get_module_api()
-        self.sync_handler = homeserver.get_sync_handler()
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.module_api = hs.get_module_api()
+        self.sync_handler = hs.get_sync_handler()
 
-    def test_send_local_online_presence_to_workers(self):
+    def test_send_local_online_presence_to_workers(self) -> None:
         # Test sending local online presence to users from a worker process
         _test_sending_local_online_presence_to_local_user(self, test_with_workers=True)
 
 
 def _test_sending_local_online_presence_to_local_user(
-    test_case: HomeserverTestCase, test_with_workers: bool = False
-):
+    test_case: BaseModuleApiTestCase, test_with_workers: bool = False
+) -> None:
     """Tests that send_local_presence_to_users sends local online presence to local users.
 
     This simultaneously tests two different usecases:
@@ -852,6 +872,7 @@ def _test_sending_local_online_presence_to_local_user(
         # Replicate the current sync presence token from the main process to the worker process.
         # We need to do this so that the worker process knows the current presence stream ID to
         # insert into the database when we call ModuleApi.send_local_online_presence_to.
+        assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
         test_case.replicate()
 
     # Syncing again should result in no presence updates
@@ -868,6 +889,7 @@ def _test_sending_local_online_presence_to_local_user(
 
     # Determine on which process (main or worker) to call ModuleApi.send_local_online_presence_to on
     if test_with_workers:
+        assert isinstance(test_case, BaseMultiWorkerStreamTestCase)
         module_api_to_use = worker_hs.get_module_api()
     else:
         module_api_to_use = test_case.module_api
@@ -875,12 +897,11 @@ def _test_sending_local_online_presence_to_local_user(
     # Trigger sending local online presence. We expect this information
     # to be saved to the database where all processes can access it.
     # Note that we're syncing via the master.
-    d = module_api_to_use.send_local_online_presence_to(
-        [
-            test_case.presence_receiver_id,
-        ]
+    d = defer.ensureDeferred(
+        module_api_to_use.send_local_online_presence_to(
+            [test_case.presence_receiver_id],
+        )
     )
-    d = defer.ensureDeferred(d)
 
     if test_with_workers:
         # In order for the required presence_set_state replication request to occur between the
@@ -897,7 +918,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update: UserPresenceState = presence_updates[0]
+    presence_update = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -908,7 +929,7 @@ def _test_sending_local_online_presence_to_local_user(
     )
     test_case.assertEqual(len(presence_updates), 1)
 
-    presence_update: UserPresenceState = presence_updates[0]
+    presence_update = presence_updates[0]
     test_case.assertEqual(presence_update.user_id, test_case.presence_sender_id)
     test_case.assertEqual(presence_update.state, "online")
 
@@ -936,12 +957,13 @@ def _test_sending_local_online_presence_to_local_user(
     test_case.assertEqual(len(presence_updates), 1)
 
     # Now trigger sending local online presence.
-    d = module_api_to_use.send_local_online_presence_to(
-        [
-            test_case.presence_receiver_id,
-        ]
+    d = defer.ensureDeferred(
+        module_api_to_use.send_local_online_presence_to(
+            [
+                test_case.presence_receiver_id,
+            ]
+        )
     )
-    d = defer.ensureDeferred(d)
 
     if test_with_workers:
         # In order for the required presence_set_state replication request to occur between the
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index da33423871..516b65cc3c 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -48,6 +48,14 @@ class FlattenDictTestCase(unittest.TestCase):
         input = {"foo": {"bar": "abc"}}
         self.assertEqual({"foo.bar": "abc"}, _flatten_dict(input))
 
+        # If a field has a dot in it, escape it.
+        input = {"m.foo": {"b\\ar": "abc"}}
+        self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input))
+        self.assertEqual(
+            {"m\\.foo.b\\\\ar": "abc"},
+            _flatten_dict(input, msc3783_escape_event_match_key=True),
+        )
+
     def test_non_string(self) -> None:
         """Non-string items are dropped."""
         input: Dict[str, Any] = {
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index aadb31ca83..db77a45ae3 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -213,7 +213,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         self.admin_user_tok = self.login("admin", "pass")
 
         self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
-        self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
+        self.url = "/_synapse/admin/v1/media/delete"
+        self.legacy_url = "/_synapse/admin/v1/media/%s/delete" % self.server_name
 
         # Move clock up to somewhat realistic time
         self.reactor.advance(1000000000)
@@ -332,11 +333,13 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
         )
 
-    def test_delete_media_never_accessed(self) -> None:
+    @parameterized.expand([(True,), (False,)])
+    def test_delete_media_never_accessed(self, use_legacy_url: bool) -> None:
         """
         Tests that media deleted if it is older than `before_ts` and never accessed
         `last_access_ts` is `NULL` and `created_ts` < `before_ts`
         """
+        url = self.legacy_url if use_legacy_url else self.url
 
         # upload and do not access
         server_and_media_id = self._create_media()
@@ -351,7 +354,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         now_ms = self.clock.time_msec()
         channel = self.make_request(
             "POST",
-            self.url + "?before_ts=" + str(now_ms),
+            url + "?before_ts=" + str(now_ms),
             access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..b50406e129 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2934,10 +2934,12 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(storage_controllers.persistence.persist_event(event, context))
 
         # Now get rooms
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index d18fc13c21..17a3b06a8e 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -16,7 +16,7 @@ import shutil
 import tempfile
 from binascii import unhexlify
 from io import BytesIO
-from typing import Any, BinaryIO, Dict, List, Optional, Union
+from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
 from unittest.mock import Mock
 from urllib import parse
 
@@ -32,6 +32,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.errors import Codes
 from synapse.events import EventBase
 from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.http.types import QueryParams
 from synapse.logging.context import make_deferred_yieldable
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
@@ -41,7 +42,7 @@ from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
 from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 from synapse.server import HomeServer
-from synapse.types import RoomAlias
+from synapse.types import JsonDict, RoomAlias
 from synapse.util import Clock
 
 from tests import unittest
@@ -201,36 +202,46 @@ class _TestImage:
     ],
 )
 class MediaRepoTests(unittest.HomeserverTestCase):
-
+    test_image: ClassVar[_TestImage]
     hijack_auth = True
     user_id = "@test:user"
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
-        self.fetches = []
+        self.fetches: List[
+            Tuple[
+                "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]",
+                str,
+                str,
+                Optional[QueryParams],
+            ]
+        ] = []
 
         def get_file(
             destination: str,
             path: str,
             output_stream: BinaryIO,
-            args: Optional[Dict[str, Union[str, List[str]]]] = None,
+            args: Optional[QueryParams] = None,
+            retry_on_dns_fail: bool = True,
             max_size: Optional[int] = None,
-        ) -> Deferred:
-            """
-            Returns tuple[int,dict,str,int] of file length, response headers,
-            absolute URI, and response code.
-            """
+            ignore_backoff: bool = False,
+        ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
+            """A mock for MatrixFederationHttpClient.get_file."""
 
-            def write_to(r):
+            def write_to(
+                r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
+            ) -> Tuple[int, Dict[bytes, List[bytes]]]:
                 data, response = r
                 output_stream.write(data)
                 return response
 
-            d = Deferred()
-            d.addCallback(write_to)
+            d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
             self.fetches.append((d, destination, path, args))
-            return make_deferred_yieldable(d)
+            # Note that this callback changes the value held by d.
+            d_after_callback = d.addCallback(write_to)
+            return make_deferred_yieldable(d_after_callback)
 
+        # Mock out the homeserver's MatrixFederationHttpClient
         client = Mock()
         client.get_file = get_file
 
@@ -461,6 +472,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         # Synapse should regenerate missing thumbnails.
         origin, media_id = self.media_id.split("/")
         info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
+        assert info is not None
         file_id = info["filesystem_id"]
 
         thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
@@ -581,7 +593,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                         "thumbnail_method": method,
                         "thumbnail_type": self.test_image.content_type,
                         "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail1{self.test_image.extension}",
+                        "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}",
                     },
                     {
                         "thumbnail_width": 32,
@@ -589,10 +601,10 @@ class MediaRepoTests(unittest.HomeserverTestCase):
                         "thumbnail_method": method,
                         "thumbnail_type": self.test_image.content_type,
                         "thumbnail_length": 256,
-                        "filesystem_id": f"thumbnail2{self.test_image.extension}",
+                        "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}",
                     },
                 ],
-                file_id=f"image{self.test_image.extension}",
+                file_id=f"image{self.test_image.extension.decode()}",
                 url_cache=None,
                 server_name=None,
             )
@@ -637,6 +649,7 @@ class TestSpamCheckerLegacy:
         self.config = config
         self.api = api
 
+    @staticmethod
     def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
         return config
 
@@ -748,7 +761,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
 
     async def check_media_file_for_spam(
         self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
-    ) -> Union[Codes, Literal["NOT_SPAM"]]:
+    ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
         buf = BytesIO()
         await file_wrapper.write_chunks_to(buf.write)
 
diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 22f99c6ab1..3285f2433c 100644
--- a/tests/scripts/test_new_matrix_user.py
+++ b/tests/scripts/test_new_matrix_user.py
@@ -12,29 +12,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 from unittest.mock import Mock, patch
 
 from synapse._scripts.register_new_matrix_user import request_registration
+from synapse.types import JsonDict
 
 from tests.unittest import TestCase
 
 
 class RegisterTestCase(TestCase):
-    def test_success(self):
+    def test_success(self) -> None:
         """
         The script will fetch a nonce, and then generate a MAC with it, and then
         post that MAC.
         """
 
-        def get(url, verify=None):
+        def get(url: str, verify: Optional[bool] = None) -> Mock:
             r = Mock()
             r.status_code = 200
             r.json = lambda: {"nonce": "a"}
             return r
 
-        def post(url, json=None, verify=None):
+        def post(
+            url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+        ) -> Mock:
             # Make sure we are sent the correct info
+            assert json is not None
             self.assertEqual(json["username"], "user")
             self.assertEqual(json["password"], "pass")
             self.assertEqual(json["nonce"], "a")
@@ -70,12 +74,12 @@ class RegisterTestCase(TestCase):
         # sys.exit shouldn't have been called.
         self.assertEqual(err_code, [])
 
-    def test_failure_nonce(self):
+    def test_failure_nonce(self) -> None:
         """
         If the script fails to fetch a nonce, it throws an error and quits.
         """
 
-        def get(url, verify=None):
+        def get(url: str, verify: Optional[bool] = None) -> Mock:
             r = Mock()
             r.status_code = 404
             r.reason = "Not Found"
@@ -107,20 +111,23 @@ class RegisterTestCase(TestCase):
         self.assertIn("ERROR! Received 404 Not Found", out)
         self.assertNotIn("Success!", out)
 
-    def test_failure_post(self):
+    def test_failure_post(self) -> None:
         """
         The script will fetch a nonce, and then if the final POST fails, will
         report an error and quit.
         """
 
-        def get(url, verify=None):
+        def get(url: str, verify: Optional[bool] = None) -> Mock:
             r = Mock()
             r.status_code = 200
             r.json = lambda: {"nonce": "a"}
             return r
 
-        def post(url, json=None, verify=None):
+        def post(
+            url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+        ) -> Mock:
             # Make sure we are sent the correct info
+            assert json is not None
             self.assertEqual(json["username"], "user")
             self.assertEqual(json["password"], "pass")
             self.assertEqual(json["nonce"], "a")
diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py
index 58b399a043..6540ed53f1 100644
--- a/tests/server_notices/test_consent.py
+++ b/tests/server_notices/test_consent.py
@@ -14,8 +14,12 @@
 
 import os
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         tmpdir = self.mktemp()
         os.mkdir(tmpdir)
@@ -53,15 +57,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
             "room_name": "Server Notices",
         }
 
-        hs = self.setup_test_homeserver(config=config)
-
-        return hs
+        return self.setup_test_homeserver(config=config)
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("bob", "abc123")
         self.access_token = self.login("bob", "abc123")
 
-    def test_get_sync_message(self):
+    def test_get_sync_message(self) -> None:
         """
         When user consent server notices are enabled, a sync will cause a notice
         to fire (in a room which the user is invited to). The notice contains
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index dadc6efcbf..5b76383d76 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -24,6 +24,7 @@ from synapse.server import HomeServer
 from synapse.server_notices.resource_limits_server_notices import (
     ResourceLimitsServerNotices,
 )
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -33,7 +34,7 @@ from tests.utils import default_config
 
 
 class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = default_config("test")
 
         config.update(
@@ -86,18 +87,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))  # type: ignore[assignment]
 
     @override_config({"hs_disabled": True})
-    def test_maybe_send_server_notice_disabled_hs(self):
+    def test_maybe_send_server_notice_disabled_hs(self) -> None:
         """If the HS is disabled, we should not send notices"""
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         self._send_notice.assert_not_called()
 
     @override_config({"limit_usage_by_mau": False})
-    def test_maybe_send_server_notice_to_user_flag_off(self):
+    def test_maybe_send_server_notice_to_user_flag_off(self) -> None:
         """If mau limiting is disabled, we should not send notices"""
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         self._send_notice.assert_not_called()
 
-    def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
+    def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
         """Test when user has blocked notice, but should have it removed"""
 
         self._rlsn._auth_blocking.check_auth_blocking = Mock(
@@ -114,7 +115,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
         self._send_notice.assert_called_once()
 
-    def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
+    def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
         """
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
@@ -134,7 +135,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
 
         self._send_notice.assert_not_called()
 
-    def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
+    def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None:
         """
         Test when user does not have blocked notice, but should have one
         """
@@ -147,7 +148,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         # Would be better to check contents, but 2 calls == set blocking event
         self.assertEqual(self._send_notice.call_count, 2)
 
-    def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
+    def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None:
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
@@ -159,7 +160,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
 
         self._send_notice.assert_not_called()
 
-    def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
+    def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None:
         """
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
@@ -175,7 +176,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._send_notice.assert_not_called()
 
     @override_config({"mau_limit_alerting": False})
-    def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
+    def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(
+        self,
+    ) -> None:
         """
         Test that when server is over MAU limit and alerting is suppressed, then
         an alert message is not sent into the room
@@ -191,7 +194,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self.assertEqual(self._send_notice.call_count, 0)
 
     @override_config({"mau_limit_alerting": False})
-    def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
+    def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None:
         """
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
@@ -207,7 +210,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self.assertEqual(self._send_notice.call_count, 2)
 
     @override_config({"mau_limit_alerting": False})
-    def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
+    def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(
+        self,
+    ) -> None:
         """
         When the room is already in a blocked state, test that when alerting
         is suppressed that the room is returned to an unblocked state.
@@ -242,7 +247,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         c = super().default_config()
         c["server_notices"] = {
             "system_mxid_localpart": "server",
@@ -270,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         self.user_id = "@user_id:test"
 
-    def test_server_notice_only_sent_once(self):
+    def test_server_notice_only_sent_once(self) -> None:
         self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
 
         self.store.user_last_seen_monthly_active = Mock(
@@ -306,7 +311,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         self.assertEqual(count, 1)
 
-    def test_no_invite_without_notice(self):
+    def test_no_invite_without_notice(self) -> None:
         """Tests that a user doesn't get invited to a server notices room without a
         server notice being sent.
 
@@ -328,7 +333,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         m.assert_called_once_with(user_id)
 
-    def test_invite_with_notice(self):
+    def test_invite_with_notice(self) -> None:
         """Tests that, if the MAU limit is hit, the server notices user invites each user
         to a room in which it has sent a notice.
         """
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             def internal_metadata(self) -> _EventInternalMetadata:
                 return self._base_builder.internal_metadata
 
-        event_1, context_1 = self.get_success(
+        event_1, unpersisted_context_1 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 cast(
                     EventBuilder,
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             )
         )
 
+        context_1 = self.get_success(unpersisted_context_1.persist(event_1))
+
         self.get_success(self._persistence.persist_event(event_1, context_1))
 
-        event_2, context_2 = self.get_success(
+        event_2, unpersisted_context_2 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 cast(
                     EventBuilder,
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 )
             )
         )
+
+        context_2 = self.get_success(unpersisted_context_2.persist(event_2))
         self.get_success(self._persistence.persist_event(event_2, context_2))
 
         # fetch one of the redactions
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        redaction_event, context = self.get_success(
+        redaction_event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(redaction_event))
+
         self.get_success(self._persistence.persist_event(redaction_event, context))
 
         # Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..f730b888f7 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         assert self.storage.persistence is not None
         self.get_success(self.storage.persistence.persist_event(event, context))
 
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 31546ea52b..a248f1d277 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -21,10 +21,10 @@ from . import unittest
 
 
 class DistributorTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.dist = Distributor()
 
-    def test_signal_dispatch(self):
+    def test_signal_dispatch(self) -> None:
         self.dist.declare("alert")
 
         observer = Mock()
@@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase):
         self.dist.fire("alert", 1, 2, 3)
         observer.assert_called_with(1, 2, 3)
 
-    def test_signal_catch(self):
+    def test_signal_catch(self) -> None:
         self.dist.declare("alarm")
 
         observers = [Mock() for i in (1, 2)]
@@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase):
             self.assertEqual(mock_logger.warning.call_count, 1)
             self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
 
-    def test_signal_prereg(self):
+    def test_signal_prereg(self) -> None:
         observer = Mock()
         self.dist.observe("flare", observer)
 
@@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase):
 
         observer.assert_called_with(4, 5)
 
-    def test_signal_undeclared(self):
-        def code():
+    def test_signal_undeclared(self) -> None:
+        def code() -> None:
             self.dist.fire("notification")
 
         self.assertRaises(KeyError, code)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 0a7937f1cc..2860564afc 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result
 class _StubEventSourceStore:
     """A stub implementation of the EventSourceStore"""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._store: Dict[str, EventBase] = {}
 
-    def add_event(self, event: EventBase):
+    def add_event(self, event: EventBase) -> None:
         self._store[event.event_id] = event
 
-    def add_events(self, events: Iterable[EventBase]):
+    def add_events(self, events: Iterable[EventBase]) -> None:
         for event in events:
             self._store[event.event_id] = event
 
@@ -59,7 +59,7 @@ class _StubEventSourceStore:
 
 
 class EventAuthTestCase(unittest.TestCase):
-    def test_rejected_auth_events(self):
+    def test_rejected_auth_events(self) -> None:
         """
         Events that refer to rejected events in their auth events are rejected
         """
@@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase):
                 )
             )
 
-    def test_create_event_with_prev_events(self):
+    def test_create_event_with_prev_events(self) -> None:
         """A create event with prev_events should be rejected
 
         https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase):
                 event_auth.check_state_independent_auth_rules(event_store, bad_event)
             )
 
-    def test_duplicate_auth_events(self):
+    def test_duplicate_auth_events(self) -> None:
         """Events with duplicate auth_events should be rejected
 
         https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase):
                 event_auth.check_state_independent_auth_rules(event_store, bad_event2)
             )
 
-    def test_unexpected_auth_events(self):
+    def test_unexpected_auth_events(self) -> None:
         """Events with excess auth_events should be rejected
 
         https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase):
                 event_auth.check_state_independent_auth_rules(event_store, bad_event)
             )
 
-    def test_random_users_cannot_send_state_before_first_pl(self):
+    def test_random_users_cannot_send_state_before_first_pl(self) -> None:
         """
         Check that, before the first PL lands, the creator is the only user
         that can send a state event.
@@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events,
         )
 
-    def test_state_default_level(self):
+    def test_state_default_level(self) -> None:
         """
         Check that users above the state_default level can send state and
         those below cannot
@@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events,
         )
 
-    def test_alias_event(self):
+    def test_alias_event(self) -> None:
         """Alias events have special behavior up through room version 6."""
         creator = "@creator:example.com"
         other = "@other:example.com"
@@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events,
         )
 
-    def test_msc2432_alias_event(self):
+    def test_msc2432_alias_event(self) -> None:
         """After MSC2432, alias events have no special behavior."""
         creator = "@creator:example.com"
         other = "@other:example.com"
@@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase):
             )
 
     @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
-    def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
+    def test_notifications(
+        self, room_version: RoomVersion, allow_modification: bool
+    ) -> None:
         """
         Notifications power levels get checked due to MSC2209.
         """
@@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
             with self.assertRaises(AuthError):
                 event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
 
-    def test_join_rules_public(self):
+    def test_join_rules_public(self) -> None:
         """
         Test joining a public room.
         """
@@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events.values(),
         )
 
-    def test_join_rules_invite(self):
+    def test_join_rules_invite(self) -> None:
         """
         Test joining an invite only room.
         """
@@ -835,7 +837,7 @@ def _power_levels_event(
     )
 
 
-def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase:
     data = {
         "room_id": TEST_ROOM_ID,
         **_maybe_get_event_id_dict_for_room_version(room_version),
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 80e5c590d8..ddb43c8c98 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -12,53 +12,48 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Union
 from unittest.mock import Mock
 
 from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import FederationError
 from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import event_from_pdu_json
+from synapse.http.types import QueryParams
 from synapse.logging.context import LoggingContext
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
 from tests import unittest
-from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
 from tests.test_utils import make_awaitable
 
 
 class MessageAcceptTests(unittest.HomeserverTestCase):
-    def setUp(self):
-
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.http_client = Mock()
-        self.reactor = ThreadedMemoryReactorClock()
-        self.hs_clock = Clock(self.reactor)
-        self.homeserver = setup_test_homeserver(
-            self.addCleanup,
-            federation_http_client=self.http_client,
-            clock=self.hs_clock,
-            reactor=self.reactor,
-        )
+        return self.setup_test_homeserver(federation_http_client=self.http_client)
 
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         user_id = UserID("us", "test")
         our_user = create_requester(user_id)
-        room_creator = self.homeserver.get_room_creation_handler()
+        room_creator = self.hs.get_room_creation_handler()
         self.room_id = self.get_success(
             room_creator.create_room(
                 our_user, room_creator._presets_dict["public_chat"], ratelimit=False
             )
         )[0]["room_id"]
 
-        self.store = self.homeserver.get_datastores().main
+        self.store = self.hs.get_datastores().main
 
         # Figure out what the most recent event is
         most_recent = self.get_success(
-            self.homeserver.get_datastores().main.get_latest_event_ids_in_room(
-                self.room_id
-            )
+            self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )[0]
 
         join_event = make_event_from_dict(
@@ -78,14 +73,16 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )
 
-        self.handler = self.homeserver.get_federation_handler()
-        federation_event_handler = self.homeserver.get_federation_event_handler()
+        self.handler = self.hs.get_federation_handler()
+        federation_event_handler = self.hs.get_federation_event_handler()
 
-        async def _check_event_auth(origin, event, context):
+        async def _check_event_auth(
+            origin: Optional[str], event: EventBase, context: EventContext
+        ) -> None:
             pass
 
         federation_event_handler._check_event_auth = _check_event_auth
-        self.client = self.homeserver.get_federation_client()
+        self.client = self.hs.get_federation_client()
         self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
             lambda dest, pdus, **k: succeed(pdus)
         )
@@ -104,16 +101,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             "$join:test.serv",
         )
 
-    def test_cant_hide_direct_ancestors(self):
+    def test_cant_hide_direct_ancestors(self) -> None:
         """
         If you send a message, you must be able to provide the direct
         prev_events that said event references.
         """
 
-        async def post_json(destination, path, data, headers=None, timeout=0):
+        async def post_json(
+            destination: str,
+            path: str,
+            data: Optional[JsonDict] = None,
+            long_retries: bool = False,
+            timeout: Optional[int] = None,
+            ignore_backoff: bool = False,
+            args: Optional[QueryParams] = None,
+        ) -> Union[JsonDict, list]:
             # If it asks us for new missing events, give them NOTHING
             if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                 return {"events": []}
+            return {}
 
         self.http_client.post_json = post_json
 
@@ -138,7 +144,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )
 
-        federation_event_handler = self.homeserver.get_federation_event_handler()
+        federation_event_handler = self.hs.get_federation_event_handler()
         with LoggingContext("test-context"):
             failure = self.get_failure(
                 federation_event_handler.on_receive_pdu("test.serv", lying_event),
@@ -158,7 +164,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
         self.assertEqual(extrem[0], "$join:test.serv")
 
-    def test_retry_device_list_resync(self):
+    def test_retry_device_list_resync(self) -> None:
         """Tests that device lists are marked as stale if they couldn't be synced, and
         that stale device lists are retried periodically.
         """
@@ -171,24 +177,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # When this function is called, increment the number of resync attempts (only if
         # we're querying devices for the right user ID), then raise a
         # NotRetryingDestination error to fail the resync gracefully.
-        def query_user_devices(destination, user_id):
+        def query_user_devices(
+            destination: str, user_id: str, timeout: int = 30000
+        ) -> JsonDict:
             if user_id == remote_user_id:
                 self.resync_attempts += 1
 
             raise NotRetryingDestination(0, 0, destination)
 
         # Register the mock on the federation client.
-        federation_client = self.homeserver.get_federation_client()
+        federation_client = self.hs.get_federation_client()
         federation_client.query_user_devices = Mock(side_effect=query_user_devices)
 
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
-        store = self.homeserver.get_datastores().main
+        store = self.hs.get_datastores().main
         store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
 
         # Manually inject a fake device list update. We need this update to include at
         # least one prev_id so that the user's device list will need to be retried.
-        device_list_updater = self.homeserver.get_device_handler().device_list_updater
+        device_list_updater = self.hs.get_device_handler().device_list_updater
         self.get_success(
             device_list_updater.incoming_device_list_update(
                 origin=remote_origin,
@@ -218,7 +226,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.reactor.advance(30)
         self.assertEqual(self.resync_attempts, 2)
 
-    def test_cross_signing_keys_retry(self):
+    def test_cross_signing_keys_retry(self) -> None:
         """Tests that resyncing a device list correctly processes cross-signing keys from
         the remote server.
         """
@@ -227,7 +235,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
         # Register mock device list retrieval on the federation client.
-        federation_client = self.homeserver.get_federation_client()
+        federation_client = self.hs.get_federation_client()
         federation_client.query_user_devices = Mock(
             return_value=make_awaitable(
                 {
@@ -252,7 +260,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         # Resync the device list.
-        device_handler = self.homeserver.get_device_handler()
+        device_handler = self.hs.get_device_handler()
         self.get_success(
             device_handler.device_list_updater.user_device_resync(remote_user_id),
         )
@@ -279,7 +287,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
 
 class StripUnsignedFromEventsTestCase(unittest.TestCase):
-    def test_strip_unauthorized_unsigned_values(self):
+    def test_strip_unauthorized_unsigned_values(self) -> None:
         event1 = {
             "sender": "@baduser:test.serv",
             "state_key": "@baduser:test.serv",
@@ -296,7 +304,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
         # Make sure unauthorized fields are stripped from unsigned
         self.assertNotIn("more warez", filtered_event.unsigned)
 
-    def test_strip_event_maintains_allowed_fields(self):
+    def test_strip_event_maintains_allowed_fields(self) -> None:
         event2 = {
             "sender": "@baduser:test.serv",
             "state_key": "@baduser:test.serv",
@@ -323,7 +331,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
         self.assertIn("invite_room_state", filtered_event2.unsigned)
         self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
 
-    def test_strip_event_removes_fields_based_on_event_type(self):
+    def test_strip_event_removes_fields_based_on_event_type(self) -> None:
         event3 = {
             "sender": "@baduser:test.serv",
             "state_key": "@baduser:test.serv",
diff --git a/tests/test_mau.py b/tests/test_mau.py
index f14fcb7db9..4e7665a22b 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -14,12 +14,17 @@
 
 """Tests REST events for /rooms paths."""
 
-from typing import List
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.rest.client import register, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.unittest import override_config
@@ -30,7 +35,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
     servlets = [register.register_servlets, sync.register_servlets]
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = default_config("test")
 
         config.update(
@@ -53,10 +58,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.store = homeserver.get_datastores().main
 
-    def test_simple_deny_mau(self):
+    def test_simple_deny_mau(self) -> None:
         # Create and sync so that the MAU counts get updated
         token1 = self.create_user("kermit1")
         self.do_sync_for_user(token1)
@@ -75,7 +82,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.code, 403)
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
-    def test_as_ignores_mau(self):
+    def test_as_ignores_mau(self) -> None:
         """Test that application services can still create users when the MAU
         limit has been reached. This only works when application service
         user ip tracking is disabled.
@@ -113,7 +120,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
         self.create_user("as_kermit4", token=as_token, appservice=True)
 
-    def test_allowed_after_a_month_mau(self):
+    def test_allowed_after_a_month_mau(self) -> None:
         # Create and sync so that the MAU counts get updated
         token1 = self.create_user("kermit1")
         self.do_sync_for_user(token1)
@@ -132,7 +139,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.do_sync_for_user(token3)
 
     @override_config({"mau_trial_days": 1})
-    def test_trial_delay(self):
+    def test_trial_delay(self) -> None:
         # We should be able to register more than the limit initially
         token1 = self.create_user("kermit1")
         self.do_sync_for_user(token1)
@@ -165,7 +172,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
     @override_config({"mau_trial_days": 1})
-    def test_trial_users_cant_come_back(self):
+    def test_trial_users_cant_come_back(self) -> None:
         self.hs.config.server.mau_trial_days = 1
 
         # We should be able to register more than the limit initially
@@ -216,7 +223,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         # max_mau_value should not matter
         {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
     )
-    def test_tracked_but_not_limited(self):
+    def test_tracked_but_not_limited(self) -> None:
         # Simply being able to create 2 users indicates that the
         # limit was not reached.
         token1 = self.create_user("kermit1")
@@ -236,10 +243,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
             "mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2},
         }
     )
-    def test_as_trial_days(self):
+    def test_as_trial_days(self) -> None:
         user_tokens: List[str] = []
 
-        def advance_time_and_sync():
+        def advance_time_and_sync() -> None:
             self.reactor.advance(24 * 60 * 61)
             for token in user_tokens:
                 self.do_sync_for_user(token)
@@ -300,7 +307,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
             },
         )
 
-    def create_user(self, localpart, token=None, appservice=False):
+    def create_user(
+        self, localpart: str, token: Optional[str] = None, appservice: bool = False
+    ) -> str:
         request_data = {
             "username": localpart,
             "password": "monkey",
@@ -326,7 +335,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
         return access_token
 
-    def do_sync_for_user(self, token):
+    def do_sync_for_user(self, token: str) -> None:
         channel = self.make_request("GET", "/sync", access_token=token)
 
         if channel.code != 200:
diff --git a/tests/test_rust.py b/tests/test_rust.py
index 55d8b6b28c..67443b6280 100644
--- a/tests/test_rust.py
+++ b/tests/test_rust.py
@@ -6,6 +6,6 @@ from tests import unittest
 class RustTestCase(unittest.TestCase):
     """Basic tests to ensure that we can call into Rust code."""
 
-    def test_basic(self):
+    def test_basic(self) -> None:
         result = sum_as_string(1, 2)
         self.assertEqual("3", result)
diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py
index d04bcae0fa..5cd698147e 100644
--- a/tests/test_test_utils.py
+++ b/tests/test_test_utils.py
@@ -17,25 +17,25 @@ from tests.utils import MockClock
 
 
 class MockClockTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.clock = MockClock()
 
-    def test_advance_time(self):
+    def test_advance_time(self) -> None:
         start_time = self.clock.time()
 
         self.clock.advance_time(20)
 
         self.assertEqual(20, self.clock.time() - start_time)
 
-    def test_later(self):
+    def test_later(self) -> None:
         invoked = [0, 0]
 
-        def _cb0():
+        def _cb0() -> None:
             invoked[0] = 1
 
         self.clock.call_later(10, _cb0)
 
-        def _cb1():
+        def _cb1() -> None:
             invoked[1] = 1
 
         self.clock.call_later(20, _cb1)
@@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase):
 
         self.assertTrue(invoked[1])
 
-    def test_cancel_later(self):
+    def test_cancel_later(self) -> None:
         invoked = [0, 0]
 
-        def _cb0():
+        def _cb0() -> None:
             invoked[0] = 1
 
         t0 = self.clock.call_later(10, _cb0)
 
-        def _cb1():
+        def _cb1() -> None:
             invoked[1] = 1
 
         self.clock.call_later(20, _cb1)
diff --git a/tests/test_types.py b/tests/test_types.py
index 1111169384..c491cc9a96 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase):
 
 
 class UserIDTestCase(unittest.HomeserverTestCase):
-    def test_parse(self):
+    def test_parse(self) -> None:
         user = UserID.from_string("@1234abcd:test")
 
         self.assertEqual("1234abcd", user.localpart)
         self.assertEqual("test", user.domain)
         self.assertEqual(True, self.hs.is_mine(user))
 
-    def test_parse_rejects_empty_id(self):
+    def test_parse_rejects_empty_id(self) -> None:
         with self.assertRaises(SynapseError):
             UserID.from_string("")
 
-    def test_parse_rejects_missing_sigil(self):
+    def test_parse_rejects_missing_sigil(self) -> None:
         with self.assertRaises(SynapseError):
             UserID.from_string("alice:example.com")
 
-    def test_parse_rejects_missing_separator(self):
+    def test_parse_rejects_missing_separator(self) -> None:
         with self.assertRaises(SynapseError):
             UserID.from_string("@alice.example.com")
 
-    def test_validation_rejects_missing_domain(self):
+    def test_validation_rejects_missing_domain(self) -> None:
         self.assertFalse(UserID.is_valid("@alice:"))
 
-    def test_build(self):
+    def test_build(self) -> None:
         user = UserID("5678efgh", "my.domain")
 
         self.assertEqual(user.to_string(), "@5678efgh:my.domain")
 
-    def test_compare(self):
+    def test_compare(self) -> None:
         userA = UserID.from_string("@userA:my.domain")
         userAagain = UserID.from_string("@userA:my.domain")
         userB = UserID.from_string("@userB:my.domain")
@@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase):
 
 
 class RoomAliasTestCase(unittest.HomeserverTestCase):
-    def test_parse(self):
+    def test_parse(self) -> None:
         room = RoomAlias.from_string("#channel:test")
 
         self.assertEqual("channel", room.localpart)
         self.assertEqual("test", room.domain)
         self.assertEqual(True, self.hs.is_mine(room))
 
-    def test_build(self):
+    def test_build(self) -> None:
         room = RoomAlias("channel", "my.domain")
 
         self.assertEqual(room.to_string(), "#channel:my.domain")
 
-    def test_validate(self):
+    def test_validate(self) -> None:
         id_string = "#test:domain,test"
         self.assertFalse(RoomAlias.is_valid(id_string))
 
 
 class MapUsernameTestCase(unittest.TestCase):
-    def testPassThrough(self):
+    def test_pass_througuh(self) -> None:
         self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
 
-    def testUpperCase(self):
+    def test_upper_case(self) -> None:
         self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
         self.assertEqual(
             map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
             "t_e_s_t__1234",
         )
 
-    def testSymbols(self):
+    def test_symbols(self) -> None:
         self.assertEqual(
             map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
         )
 
-    def testLeadingUnderscore(self):
+    def test_leading_underscore(self) -> None:
         self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
 
-    def testNonAscii(self):
+    def test_non_ascii(self) -> None:
         # this should work with either a unicode or a bytes
         self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
         self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index e62ebcc6a5..e5dae670a7 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -20,12 +20,13 @@ import sys
 import warnings
 from asyncio import Future
 from binascii import unhexlify
-from typing import Awaitable, Callable, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
 from unittest.mock import Mock
 
 import attr
 import zope.interface
 
+from twisted.internet.interfaces import IProtocol
 from twisted.python.failure import Failure
 from twisted.web.client import ResponseDone
 from twisted.web.http import RESPONSES
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
 
 from synapse.types import JsonDict
 
+if TYPE_CHECKING:
+    from sys import UnraisableHookArgs
+
 TV = TypeVar("TV")
 
 
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
     unraisable_exceptions = []
     orig_unraisablehook = sys.unraisablehook
 
-    def unraisablehook(unraisable):
+    def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
         unraisable_exceptions.append(unraisable.exc_value)
 
-    def cleanup():
+    def cleanup() -> None:
         """
         A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
         """
         sys.unraisablehook = orig_unraisablehook
         if unraisable_exceptions:
-            raise unraisable_exceptions.pop()
+            exc = unraisable_exceptions.pop()
+            assert exc is not None
+            raise exc
 
     sys.unraisablehook = unraisablehook
 
     return cleanup
 
 
-def simple_async_mock(return_value=None, raises=None) -> Mock:
+def simple_async_mock(
+    return_value: Optional[TV] = None, raises: Optional[Exception] = None
+) -> Mock:
     # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args, **kwargs):
+    async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
         if raises:
             raise raises
         return return_value
@@ -125,14 +133,14 @@ class FakeResponse:  # type: ignore[misc]
     headers: Headers = attr.Factory(Headers)
 
     @property
-    def phrase(self):
+    def phrase(self) -> bytes:
         return RESPONSES.get(self.code, b"Unknown Status")
 
     @property
-    def length(self):
+    def length(self) -> int:
         return len(self.body)
 
-    def deliverBody(self, protocol):
+    def deliverBody(self, protocol: IProtocol) -> None:
         protocol.dataReceived(self.body)
         protocol.connectionLost(Failure(ResponseDone()))
 
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 8027c7a856..a6330ed840 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import synapse.server
 from synapse.api.constants import EventTypes
@@ -32,7 +32,7 @@ async def inject_member_event(
     membership: str,
     target: Optional[str] = None,
     extra_content: Optional[dict] = None,
-    **kwargs,
+    **kwargs: Any,
 ) -> EventBase:
     """Inject a membership event into a room."""
     if target is None:
@@ -57,7 +57,7 @@ async def inject_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[List[str]] = None,
-    **kwargs,
+    **kwargs: Any,
 ) -> EventBase:
     """Inject a generic event into a room
 
@@ -82,7 +82,7 @@ async def create_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[List[str]] = None,
-    **kwargs,
+    **kwargs: Any,
 ) -> Tuple[EventBase, EventContext]:
     if room_version is None:
         room_version = await hs.get_datastores().main.get_room_version_id(
@@ -92,8 +92,13 @@ async def create_event(
     builder = hs.get_event_builder_factory().for_room_version(
         KNOWN_ROOM_VERSIONS[room_version], kwargs
     )
-    event, context = await hs.get_event_creation_handler().create_new_client_event(
+    (
+        event,
+        unpersisted_context,
+    ) = await hs.get_event_creation_handler().create_new_client_event(
         builder, prev_event_ids=prev_event_ids
     )
 
+    context = await unpersisted_context.persist(event)
+
     return event, context
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
index e878af5f12..189c697efb 100644
--- a/tests/test_utils/html_parsers.py
+++ b/tests/test_utils/html_parsers.py
@@ -13,13 +13,13 @@
 # limitations under the License.
 
 from html.parser import HTMLParser
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, NoReturn, Optional, Tuple
 
 
 class TestHtmlParser(HTMLParser):
     """A generic HTML page parser which extracts useful things from the HTML"""
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
 
         # a list of links found in the doc
@@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser):
                 assert input_name
                 self.hiddens[input_name] = attr_dict["value"]
 
-    def error(_, message):
+    def error(self, message: str) -> NoReturn:
         raise AssertionError(message)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 304c7b98c5..b522163a34 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler):
 
     tx_log = twisted.logger.Logger()
 
-    def emit(self, record):
+    def emit(self, record: logging.LogRecord) -> None:
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace("warning", "warn")
         self.tx_log.emit(
@@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler):
         )
 
 
-def setup_logging():
+def setup_logging() -> None:
     """Configure the python logging appropriately for the tests.
 
     (Logs will end up in _trial_temp.)
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@
 
 
 import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch
 from urllib.parse import parse_qs
 
@@ -77,14 +77,14 @@ class FakeOidcServer:
 
         self._id_token_overrides: Dict[str, Any] = {}
 
-    def reset_mocks(self):
+    def reset_mocks(self) -> None:
         self.request.reset_mock()
         self.get_jwks_handler.reset_mock()
         self.get_metadata_handler.reset_mock()
         self.get_userinfo_handler.reset_mock()
         self.post_token_handler.reset_mock()
 
-    def patch_homeserver(self, hs: HomeServer):
+    def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
         """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
 
         This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
 
         return self._sign(logout_token)
 
-    def id_token_override(self, overrides: dict):
+    def id_token_override(self, overrides: dict) -> ContextManager[dict]:
         """Temporarily patch the ID token generated by the token endpoint."""
         return patch.object(self, "_id_token_overrides", overrides)
 
@@ -247,7 +247,7 @@ class FakeOidcServer:
         metadata: bool = False,
         token: bool = False,
         userinfo: bool = False,
-    ):
+    ) -> ContextManager[Dict[str, Mock]]:
         """A context which makes a set of endpoints return a 500 error.
 
         Args:
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index d0b9ad5454..36d6b37aa4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -175,9 +175,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
+        context = self.get_success(unpersisted_context.persist(event))
         self.get_success(
             self._storage_controllers.persistence.persist_event(event, context)
         )
@@ -202,9 +203,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
         self.get_success(
             self._storage_controllers.persistence.persist_event(event, context)
@@ -226,9 +228,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
+        context = self.get_success(unpersisted_context.persist(event))
 
         self.get_success(
             self._storage_controllers.persistence.persist_event(event, context)
@@ -258,7 +261,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
 
 class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
-    def test_out_of_band_invite_rejection(self):
+    def test_out_of_band_invite_rejection(self) -> None:
         # this is where we have received an invite event over federation, and then
         # rejected it.
         invite_pdu = {
diff --git a/tests/unittest.py b/tests/unittest.py
index fa92dd94eb..68e59a88dc 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -315,7 +315,7 @@ class HomeserverTestCase(TestCase):
 
                 # This has to be a function and not just a Mock, because
                 # `self.helper.auth_user_id` is temporarily reassigned in some tests
-                async def get_requester(*args, **kwargs) -> Requester:
+                async def get_requester(*args: Any, **kwargs: Any) -> Requester:
                     assert self.helper.auth_user_id is not None
                     return create_requester(
                         user_id=UserID.from_string(self.helper.auth_user_id),
diff --git a/tests/utils.py b/tests/utils.py
index d76bf9716a..15fabbc2d0 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -335,6 +335,9 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
         },
     )
 
-    event, context = await event_creation_handler.create_new_client_event(builder)
+    event, unpersisted_context = await event_creation_handler.create_new_client_event(
+        builder
+    )
+    context = await unpersisted_context.persist(event)
 
     await persistence_store.persist_event(event, context)