From 42aea0d8af1556473b4f31f78d9facb448230a1f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 14 Feb 2023 14:03:35 -0500 Subject: Add final type hint to tests.unittest. (#15072) Adds a return type to HomeServerTestCase.make_homeserver and deal with any variables which are no longer Any. --- tests/handlers/test_appservice.py | 2 +- tests/handlers/test_cas.py | 8 ++--- tests/handlers/test_e2e_keys.py | 55 ++++++++++++++++--------------- tests/handlers/test_federation.py | 18 +++++----- tests/handlers/test_federation_event.py | 6 ++-- tests/handlers/test_message.py | 11 ++++--- tests/handlers/test_password_providers.py | 2 +- tests/handlers/test_register.py | 10 ++++-- tests/handlers/test_saml.py | 14 ++++---- tests/handlers/test_typing.py | 12 +++---- tests/handlers/test_user_directory.py | 33 +++++++++++-------- 11 files changed, 93 insertions(+), 78 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index a7495ab21a..9014e60577 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] return_value=self._services ) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 2733719d82..63aad0d10c 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] cas_response = CasResponse("test_user", {}) request = _mock_request() @@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # Map a user via SSO. cas_response = CasResponse("test_user", {}) @@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] cas_response = CasResponse("föö", {}) request = _mock_request() @@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # The response doesn't have the proper userGroup or department. cas_response = CasResponse("test_user", {}) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 95698bc275..6b4cba65d0 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) # we should now have an unused alg1 key - res = self.get_success( + fallback_res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) - self.assertEqual(res, ["alg1"]) + self.assertEqual(fallback_res, ["alg1"]) # claiming an OTK when no OTKs are available should return the fallback # key - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) # we shouldn't have any unused fallback keys again - res = self.get_success( + unused_res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) - self.assertEqual(res, []) + self.assertEqual(unused_res, []) # claiming an OTK again should return the same fallback key - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, ) @@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) - res = self.get_success( + unused_res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) - self.assertEqual(res, []) + self.assertEqual(unused_res, []) # uploading a new fallback key should result in an unused fallback key self.get_success( @@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) - res = self.get_success( + unused_res = self.get_success( self.store.get_e2e_unused_fallback_key_types(local_user, device_id) ) - self.assertEqual(res, ["alg1"]) + self.assertEqual(unused_res, ["alg1"]) # if the user uploads a one-time key, the next claim should fetch the # one-time key, and then go back to the fallback @@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, ) - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}}, ) @@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) ) - res = self.get_success( + claim_res = self.get_success( self.handler.claim_one_time_keys( {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None ) ) self.assertEqual( - res, + claim_res, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, ) @@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) # upload two device keys, which will be signed later by the self-signing key - device_key_1 = { + device_key_1: JsonDict = { "user_id": local_user, "device_id": "abc", "algorithms": [ @@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, "signatures": {local_user: {"ed25519:abc": "base64+signature"}}, } - device_key_2 = { + device_key_2: JsonDict = { "user_id": local_user, "device_id": "def", "algorithms": [ @@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): } self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1)) + device_handler = self.hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) e = self.get_failure( - self.hs.get_device_handler().check_device_registered( + device_handler.check_device_registered( user_id=local_user, device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", initial_device_display_name="new display name", @@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): device_id = "xyz" # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY" - device_key = { + device_key: JsonDict = { "user_id": local_user, "device_id": device_id, "algorithms": [ @@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" - master_key = { + master_key: JsonDict = { "user_id": local_user, "usage": ["master"], "keys": {"ed25519:" + master_pubkey: master_pubkey}, @@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # the first user other_user = "@otherboris:" + self.hs.hostname other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM" - other_master_key = { + other_master_key: JsonDict = { # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI "user_id": other_user, "usage": ["master"], @@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_client_keys = mock.Mock( + self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment] return_value=make_awaitable( { "device_keys": {remote_user_id: {}}, @@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_user_devices = mock.Mock( + self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment] return_value=make_awaitable( { "user_id": remote_user_id, diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 5868eb2da7..bf0862ed54 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # We mock out the FederationClient.backfill method, to pretend that a remote # server has returned our fake event. federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) - self.hs.get_federation_client().backfill = federation_client_backfill_mock + self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment] # We also mock the persist method with a side effect of itself. This allows us # to track when it has been called while preserving its function. persist_events_and_notify_mock = Mock( side_effect=self.hs.get_federation_event_handler().persist_events_and_notify ) - self.hs.get_federation_event_handler().persist_events_and_notify = ( + self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment] persist_events_and_notify_mock ) @@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): # Start the partial state sync. - fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 1) # Try to start another partial state sync. # Nothing should happen. - fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 1) # End the partial state sync @@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): # The next attempt to start the partial state sync should work. is_partial_state = True - fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 2) def test_partial_state_room_sync_restart(self) -> None: @@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): # Start the partial state sync. - fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 1) # Fail the partial state sync. @@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(mock_sync_partial_state_room.call_count, 1) # Start the partial state sync again. - fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 2) # Deduplicate another partial state sync. - fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id") self.assertEqual(mock_sync_partial_state_room.call_count, 2) # Fail the partial state sync. @@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(mock_sync_partial_state_room.call_count, 3) mock_sync_partial_state_room.assert_called_with( initial_destination="hs3", - other_destinations=["hs2"], + other_destinations={"hs2"}, room_id="room_id", ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 70ea4d15d4..c067e5bfe3 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer +from synapse.state import StateResolutionStore from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort from synapse.types import JsonDict from synapse.util import Clock @@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): if prev_exists_as_outlier: prev_event.internal_metadata.outlier = True persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None self.get_success( persistence.persist_event( prev_event, @@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): bert_member_event.event_id: bert_member_event, rejected_kick_event.event_id: rejected_kick_event, }, - state_res_store=main_store, + state_res_store=StateResolutionStore(main_store), ) ), [bert_member_event.event_id, rejected_kick_event.event_id], @@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): rejected_power_levels_event.event_id, ], event_map={}, - state_res_store=main_store, + state_res_store=StateResolutionStore(main_store), full_conflicted_set=set(), ) ), diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index c4727ab917..69d384442f 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = self.hs.get_event_creation_handler() - self._persist_event_storage_controller = ( - self.hs.get_storage_controllers().persistence - ) + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self._persist_event_storage_controller = persistence self.user_id = self.register_user("tester", "foobar") self.access_token = self.login("tester", "foobar") self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) - self.info = self.get_success( + info = self.get_success( self.hs.get_datastores().main.get_user_by_access_token( self.access_token, ) ) - self.token_id = self.info.token_id + assert info is not None + self.token_id = info.token_id self.requester = create_requester(self.user_id, access_token_id=self.token_id) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 0916de64f5..aa91bc0a3d 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): username: The username to use for the test. registration: Whether to test with registration URLs. """ - self.hs.get_identity_handler().send_threepid_validation = Mock( + self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment] return_value=make_awaitable(0), ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index b9332d97dc..782ef09cf4 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self) -> None: - self.store.count_monthly_users = Mock( + self.store.count_monthly_users = Mock( # type: ignore[assignment] return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception @@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: room_alias_str = "#room:test" - self.store.count_real_users = Mock(return_value=make_awaitable(1)) + self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment] self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( self, ) -> None: - self.store.count_real_users = Mock(return_value=make_awaitable(2)) + self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment] self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) @@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly not federated. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + assert room is not None self.assertFalse(room["federatable"]) self.assertFalse(room["public"]) self.assertEqual(room["join_rules"], "public") @@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly a public room. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + assert room is not None self.assertEqual(room["join_rules"], "public") # Both users should be in the room. @@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly a private room. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + assert room is not None self.assertFalse(room["public"]) self.assertEqual(room["join_rules"], "invite") self.assertEqual(room["guest_access"], "can_join") @@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Ensure the room is properly a private room. room = self.get_success(self.store.get_room_with_stats(room_id["room_id"])) + assert room is not None self.assertFalse(room["public"]) self.assertEqual(room["join_rules"], "invite") self.assertEqual(room["guest_access"], "can_join") diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 9b1b8b9f13..b5c772a7ae 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) @@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # Map a user via SSO. saml_response = FakeAuthnResponse( @@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # mock out the error renderer too sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) + sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) request = _mock_request() @@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler and error renderer auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) + sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] # register a user to occupy the first-choice MXID store = self.hs.get_datastores().main @@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() + auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] # The response doesn't have the proper userGroup or department. saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 1fe9563c98..94518a7196 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_keyring.verify_json_for_server.return_value = make_awaitable(True) # we mock out the federation client too - mock_federation_client = Mock(spec=["put_json"]) - mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) + self.mock_federation_client = Mock(spec=["put_json"]) + self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) @@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.mock_hs_notifier = Mock() hs = self.setup_test_homeserver( notifier=self.mock_hs_notifier, - federation_http_client=mock_federation_client, + federation_http_client=self.mock_federation_client, keyring=mock_keyring, replication_streams={}, ) @@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - put_json = self.hs.get_federation_http_client().put_json - put_json.assert_called_once_with( + self.mock_federation_client.put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( @@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) - put_json = self.hs.get_federation_http_client().put_json - put_json.assert_called_once_with( + self.mock_federation_client.put_json.assert_called_once_with( "farm", path="/_matrix/federation/v1/send/1000000", data=_expect_edu_transaction( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index e9be5fb504..f65a68b9c2 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Any, Tuple from unittest.mock import Mock, patch from urllib.parse import quote @@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService from synapse.rest.client import login, register, room, user_directory from synapse.server import HomeServer from synapse.storage.roommember import ProfileInfo -from synapse.types import create_requester +from synapse.types import UserProfile, create_requester from synapse.util import Clock from tests import unittest @@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config +# A spam checker which doesn't implement anything, so create a bare object. +class UselessSpamChecker: + def __init__(self, config: Any): + pass + + class UserDirectoryTestCase(unittest.HomeserverTestCase): """Tests the UserDirectoryHandler. @@ -773,7 +779,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) - async def allow_all(user_profile: ProfileInfo) -> bool: + async def allow_all(user_profile: UserProfile) -> bool: # Allow all users. return False @@ -787,7 +793,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - async def block_all(user_profile: ProfileInfo) -> bool: + async def block_all(user_profile: UserProfile) -> bool: # All users are spammy. return True @@ -797,6 +803,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 0) + @override_config( + { + "spam_checker": { + "module": "tests.handlers.test_user_directory.UselessSpamChecker" + } + } + ) def test_legacy_spam_checker(self) -> None: """ A spam checker without the expected method should be ignored. @@ -825,11 +838,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) self.assertEqual(public_users, set()) - # Configure a spam checker. - spam_checker = self.hs.get_spam_checker() - # The spam checker doesn't need any methods, so create a bare object. - spam_checker.spam_checker = object() - # We get one search result when searching for user2 by user1. s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 1) @@ -954,10 +962,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) context = self.get_success(unpersisted_context.persist(event)) - - self.get_success( - self.hs.get_storage_controllers().persistence.persist_event(event, context) - ) + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self.get_success(persistence.persist_event(event, context)) def test_local_user_leaving_room_remains_in_user_directory(self) -> None: """We've chosen to simplify the user directory's implementation by -- cgit 1.4.1