summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_appservice.py2
-rw-r--r--tests/handlers/test_cas.py8
-rw-r--r--tests/handlers/test_e2e_keys.py55
-rw-r--r--tests/handlers/test_federation.py18
-rw-r--r--tests/handlers/test_federation_event.py6
-rw-r--r--tests/handlers/test_message.py11
-rw-r--r--tests/handlers/test_password_providers.py2
-rw-r--r--tests/handlers/test_register.py10
-rw-r--r--tests/handlers/test_saml.py14
-rw-r--r--tests/handlers/test_typing.py12
-rw-r--r--tests/handlers/test_user_directory.py33
11 files changed, 93 insertions, 78 deletions
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index a7495ab21a..9014e60577 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
 
         # Mock out application services, and allow defining our own in tests
         self._services: List[ApplicationService] = []
-        self.hs.get_datastores().main.get_app_services = Mock(
+        self.hs.get_datastores().main.get_app_services = Mock(  # type: ignore[assignment]
             return_value=self._services
         )
 
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 2733719d82..63aad0d10c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         cas_response = CasResponse("test_user", {})
         request = _mock_request()
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # Map a user via SSO.
         cas_response = CasResponse("test_user", {})
@@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         cas_response = CasResponse("föö", {})
         request = _mock_request()
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # The response doesn't have the proper userGroup or department.
         cas_response = CasResponse("test_user", {})
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 95698bc275..6b4cba65d0 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
 
         # we should now have an unused alg1 key
-        res = self.get_success(
+        fallback_res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
-        self.assertEqual(res, ["alg1"])
+        self.assertEqual(fallback_res, ["alg1"])
 
         # claiming an OTK when no OTKs are available should return the fallback
         # key
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
 
         # we shouldn't have any unused fallback keys again
-        res = self.get_success(
+        unused_res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
-        self.assertEqual(res, [])
+        self.assertEqual(unused_res, [])
 
         # claiming an OTK again should return the same fallback key
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
 
@@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        res = self.get_success(
+        unused_res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
-        self.assertEqual(res, [])
+        self.assertEqual(unused_res, [])
 
         # uploading a new fallback key should result in an unused fallback key
         self.get_success(
@@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        res = self.get_success(
+        unused_res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
-        self.assertEqual(res, ["alg1"])
+        self.assertEqual(unused_res, ["alg1"])
 
         # if the user uploads a one-time key, the next claim should fetch the
         # one-time key, and then go back to the fallback
@@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
         )
 
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
         )
 
@@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        res = self.get_success(
+        claim_res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res,
+            claim_res,
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
 
@@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
         # upload two device keys, which will be signed later by the self-signing key
-        device_key_1 = {
+        device_key_1: JsonDict = {
             "user_id": local_user,
             "device_id": "abc",
             "algorithms": [
@@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
             "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
         }
-        device_key_2 = {
+        device_key_2: JsonDict = {
             "user_id": local_user,
             "device_id": "def",
             "algorithms": [
@@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         }
         self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
+        device_handler = self.hs.get_device_handler()
+        assert isinstance(device_handler, DeviceHandler)
         e = self.get_failure(
-            self.hs.get_device_handler().check_device_registered(
+            device_handler.check_device_registered(
                 user_id=local_user,
                 device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
                 initial_device_display_name="new display name",
@@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         device_id = "xyz"
         # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
         device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
-        device_key = {
+        device_key: JsonDict = {
             "user_id": local_user,
             "device_id": device_id,
             "algorithms": [
@@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 
         # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
         master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
-        master_key = {
+        master_key: JsonDict = {
             "user_id": local_user,
             "usage": ["master"],
             "keys": {"ed25519:" + master_pubkey: master_pubkey},
@@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # the first user
         other_user = "@otherboris:" + self.hs.hostname
         other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
-        other_master_key = {
+        other_master_key: JsonDict = {
             # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
             "user_id": other_user,
             "usage": ["master"],
@@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_client_keys = mock.Mock(
+        self.hs.get_federation_client().query_client_keys = mock.Mock(  # type: ignore[assignment]
             return_value=make_awaitable(
                 {
                     "device_keys": {remote_user_id: {}},
@@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
-        self.hs.get_federation_client().query_user_devices = mock.Mock(
+        self.hs.get_federation_client().query_user_devices = mock.Mock(  # type: ignore[assignment]
             return_value=make_awaitable(
                 {
                     "user_id": remote_user_id,
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 5868eb2da7..bf0862ed54 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
         # We mock out the FederationClient.backfill method, to pretend that a remote
         # server has returned our fake event.
         federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
-        self.hs.get_federation_client().backfill = federation_client_backfill_mock
+        self.hs.get_federation_client().backfill = federation_client_backfill_mock  # type: ignore[assignment]
 
         # We also mock the persist method with a side effect of itself. This allows us
         # to track when it has been called while preserving its function.
         persist_events_and_notify_mock = Mock(
             side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
         )
-        self.hs.get_federation_event_handler().persist_events_and_notify = (
+        self.hs.get_federation_event_handler().persist_events_and_notify = (  # type: ignore[assignment]
             persist_events_and_notify_mock
         )
 
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
         ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
             # Start the partial state sync.
-            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # Try to start another partial state sync.
             # Nothing should happen.
-            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # End the partial state sync
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
 
             # The next attempt to start the partial state sync should work.
             is_partial_state = True
-            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 2)
 
     def test_partial_state_room_sync_restart(self) -> None:
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
         ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
             # Start the partial state sync.
-            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # Fail the partial state sync.
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             self.assertEqual(mock_sync_partial_state_room.call_count, 1)
 
             # Start the partial state sync again.
-            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 2)
 
             # Deduplicate another partial state sync.
-            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
             self.assertEqual(mock_sync_partial_state_room.call_count, 2)
 
             # Fail the partial state sync.
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             self.assertEqual(mock_sync_partial_state_room.call_count, 3)
             mock_sync_partial_state_room.assert_called_with(
                 initial_destination="hs3",
-                other_destinations=["hs2"],
+                other_destinations={"hs2"},
                 room_id="room_id",
             )
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 70ea4d15d4..c067e5bfe3 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
+from synapse.state import StateResolutionStore
 from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
         if prev_exists_as_outlier:
             prev_event.internal_metadata.outlier = True
             persistence = self.hs.get_storage_controllers().persistence
+            assert persistence is not None
             self.get_success(
                 persistence.persist_event(
                     prev_event,
@@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                         bert_member_event.event_id: bert_member_event,
                         rejected_kick_event.event_id: rejected_kick_event,
                     },
-                    state_res_store=main_store,
+                    state_res_store=StateResolutionStore(main_store),
                 )
             ),
             [bert_member_event.event_id, rejected_kick_event.event_id],
@@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
                         rejected_power_levels_event.event_id,
                     ],
                     event_map={},
-                    state_res_store=main_store,
+                    state_res_store=StateResolutionStore(main_store),
                     full_conflicted_set=set(),
                 )
             ),
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index c4727ab917..69d384442f 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = self.hs.get_event_creation_handler()
-        self._persist_event_storage_controller = (
-            self.hs.get_storage_controllers().persistence
-        )
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self._persist_event_storage_controller = persistence
 
         self.user_id = self.register_user("tester", "foobar")
         self.access_token = self.login("tester", "foobar")
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
 
-        self.info = self.get_success(
+        info = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(
                 self.access_token,
             )
         )
-        self.token_id = self.info.token_id
+        assert info is not None
+        self.token_id = info.token_id
 
         self.requester = create_requester(self.user_id, access_token_id=self.token_id)
 
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 0916de64f5..aa91bc0a3d 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             username: The username to use for the test.
             registration: Whether to test with registration URLs.
         """
-        self.hs.get_identity_handler().send_threepid_validation = Mock(
+        self.hs.get_identity_handler().send_threepid_validation = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(0),
         )
 
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index b9332d97dc..782ef09cf4 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_not_blocked(self) -> None:
-        self.store.count_monthly_users = Mock(
+        self.store.count_monthly_users = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
         )
         # Ensure does not throw exception
@@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
         room_alias_str = "#room:test"
 
-        self.store.count_real_users = Mock(return_value=make_awaitable(1))
+        self.store.count_real_users = Mock(return_value=make_awaitable(1))  # type: ignore[assignment]
         self.store.is_real_user = Mock(return_value=make_awaitable(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
         self,
     ) -> None:
-        self.store.count_real_users = Mock(return_value=make_awaitable(2))
+        self.store.count_real_users = Mock(return_value=make_awaitable(2))  # type: ignore[assignment]
         self.store.is_real_user = Mock(return_value=make_awaitable(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # Ensure the room is properly not federated.
         room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        assert room is not None
         self.assertFalse(room["federatable"])
         self.assertFalse(room["public"])
         self.assertEqual(room["join_rules"], "public")
@@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # Ensure the room is properly a public room.
         room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        assert room is not None
         self.assertEqual(room["join_rules"], "public")
 
         # Both users should be in the room.
@@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # Ensure the room is properly a private room.
         room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        assert room is not None
         self.assertFalse(room["public"])
         self.assertEqual(room["join_rules"], "invite")
         self.assertEqual(room["guest_access"], "can_join")
@@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
         # Ensure the room is properly a private room.
         room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        assert room is not None
         self.assertFalse(room["public"])
         self.assertEqual(room["join_rules"], "invite")
         self.assertEqual(room["guest_access"], "can_join")
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 9b1b8b9f13..b5c772a7ae 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # send a mocked-up SAML response to the callback
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # Map a user via SSO.
         saml_response = FakeAuthnResponse(
@@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # mock out the error renderer too
         sso_handler = self.hs.get_sso_handler()
-        sso_handler.render_error = Mock(return_value=None)
+        sso_handler.render_error = Mock(return_value=None)  # type: ignore[assignment]
 
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
         request = _mock_request()
@@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler and error renderer
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
         sso_handler = self.hs.get_sso_handler()
-        sso_handler.render_error = Mock(return_value=None)
+        sso_handler.render_error = Mock(return_value=None)  # type: ignore[assignment]
 
         # register a user to occupy the first-choice MXID
         store = self.hs.get_datastores().main
@@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
 
         # The response doesn't have the proper userGroup or department.
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 1fe9563c98..94518a7196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
 
         # we mock out the federation client too
-        mock_federation_client = Mock(spec=["put_json"])
-        mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+        self.mock_federation_client = Mock(spec=["put_json"])
+        self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
 
         # the tests assume that we are starting at unix time 1000
         reactor.pump((1000,))
@@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.mock_hs_notifier = Mock()
         hs = self.setup_test_homeserver(
             notifier=self.mock_hs_notifier,
-            federation_http_client=mock_federation_client,
+            federation_http_client=self.mock_federation_client,
             keyring=mock_keyring,
             replication_streams={},
         )
@@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        put_json = self.hs.get_federation_http_client().put_json
-        put_json.assert_called_once_with(
+        self.mock_federation_client.put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
             data=_expect_edu_transaction(
@@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
-        put_json = self.hs.get_federation_http_client().put_json
-        put_json.assert_called_once_with(
+        self.mock_federation_client.put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
             data=_expect_edu_transaction(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index e9be5fb504..f65a68b9c2 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple
+from typing import Any, Tuple
 from unittest.mock import Mock, patch
 from urllib.parse import quote
 
@@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService
 from synapse.rest.client import login, register, room, user_directory
 from synapse.server import HomeServer
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import create_requester
+from synapse.types import UserProfile, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import override_config
 
 
+# A spam checker which doesn't implement anything, so create a bare object.
+class UselessSpamChecker:
+    def __init__(self, config: Any):
+        pass
+
+
 class UserDirectoryTestCase(unittest.HomeserverTestCase):
     """Tests the UserDirectoryHandler.
 
@@ -773,7 +779,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 1)
 
-        async def allow_all(user_profile: ProfileInfo) -> bool:
+        async def allow_all(user_profile: UserProfile) -> bool:
             # Allow all users.
             return False
 
@@ -787,7 +793,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(s["results"]), 1)
 
         # Configure a spam checker that filters all users.
-        async def block_all(user_profile: ProfileInfo) -> bool:
+        async def block_all(user_profile: UserProfile) -> bool:
             # All users are spammy.
             return True
 
@@ -797,6 +803,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 0)
 
+    @override_config(
+        {
+            "spam_checker": {
+                "module": "tests.handlers.test_user_directory.UselessSpamChecker"
+            }
+        }
+    )
     def test_legacy_spam_checker(self) -> None:
         """
         A spam checker without the expected method should be ignored.
@@ -825,11 +838,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
         self.assertEqual(public_users, set())
 
-        # Configure a spam checker.
-        spam_checker = self.hs.get_spam_checker()
-        # The spam checker doesn't need any methods, so create a bare object.
-        spam_checker.spam_checker = object()
-
         # We get one search result when searching for user2 by user1.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 1)
@@ -954,10 +962,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
 
         context = self.get_success(unpersisted_context.persist(event))
-
-        self.get_success(
-            self.hs.get_storage_controllers().persistence.persist_event(event, context)
-        )
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self.get_success(persistence.persist_event(event, context))
 
     def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
         """We've chosen to simplify the user directory's implementation by