From 30509a1010f10bc7924146cac57571c4b24914d7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Feb 2023 16:29:49 -0500 Subject: Add more missing type hints to tests. (#15028) --- tests/handlers/test_oidc.py | 4 +- tests/scripts/test_new_matrix_user.py | 25 ++++--- tests/server_notices/test_consent.py | 14 ++-- .../test_resource_limits_server_notices.py | 35 ++++++---- tests/test_federation.py | 80 ++++++++++++---------- tests/test_utils/__init__.py | 26 ++++--- tests/test_utils/event_injection.py | 8 +-- tests/test_utils/html_parsers.py | 6 +- tests/test_utils/logging_setup.py | 4 +- tests/test_utils/oidc.py | 10 +-- tests/test_visibility.py | 2 +- tests/unittest.py | 2 +- 12 files changed, 123 insertions(+), 93 deletions(-) (limited to 'tests') diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index adddbd002f..951caaa6b3 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase): hs = self.setup_test_homeserver() self.hs_patcher = self.fake_server.patch_homeserver(hs=hs) - self.hs_patcher.start() + self.hs_patcher.start() # type: ignore[attr-defined] self.handler = hs.get_oidc_handler() self.provider = self.handler._providers["oidc"] @@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return hs def tearDown(self) -> None: - self.hs_patcher.stop() + self.hs_patcher.stop() # type: ignore[attr-defined] return super().tearDown() def reset_mocks(self) -> None: diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py index 22f99c6ab1..3285f2433c 100644 --- a/tests/scripts/test_new_matrix_user.py +++ b/tests/scripts/test_new_matrix_user.py @@ -12,29 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from unittest.mock import Mock, patch from synapse._scripts.register_new_matrix_user import request_registration +from synapse.types import JsonDict from tests.unittest import TestCase class RegisterTestCase(TestCase): - def test_success(self): + def test_success(self) -> None: """ The script will fetch a nonce, and then generate a MAC with it, and then post that MAC. """ - def get(url, verify=None): + def get(url: str, verify: Optional[bool] = None) -> Mock: r = Mock() r.status_code = 200 r.json = lambda: {"nonce": "a"} return r - def post(url, json=None, verify=None): + def post( + url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None + ) -> Mock: # Make sure we are sent the correct info + assert json is not None self.assertEqual(json["username"], "user") self.assertEqual(json["password"], "pass") self.assertEqual(json["nonce"], "a") @@ -70,12 +74,12 @@ class RegisterTestCase(TestCase): # sys.exit shouldn't have been called. self.assertEqual(err_code, []) - def test_failure_nonce(self): + def test_failure_nonce(self) -> None: """ If the script fails to fetch a nonce, it throws an error and quits. """ - def get(url, verify=None): + def get(url: str, verify: Optional[bool] = None) -> Mock: r = Mock() r.status_code = 404 r.reason = "Not Found" @@ -107,20 +111,23 @@ class RegisterTestCase(TestCase): self.assertIn("ERROR! Received 404 Not Found", out) self.assertNotIn("Success!", out) - def test_failure_post(self): + def test_failure_post(self) -> None: """ The script will fetch a nonce, and then if the final POST fails, will report an error and quit. """ - def get(url, verify=None): + def get(url: str, verify: Optional[bool] = None) -> Mock: r = Mock() r.status_code = 200 r.json = lambda: {"nonce": "a"} return r - def post(url, json=None, verify=None): + def post( + url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None + ) -> Mock: # Make sure we are sent the correct info + assert json is not None self.assertEqual(json["username"], "user") self.assertEqual(json["password"], "pass") self.assertEqual(json["nonce"], "a") diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index 58b399a043..6540ed53f1 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -14,8 +14,12 @@ import os +from twisted.test.proto_helpers import MemoryReactor + import synapse.rest.admin from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest @@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): room.register_servlets, ] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: tmpdir = self.mktemp() os.mkdir(tmpdir) @@ -53,15 +57,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): "room_name": "Server Notices", } - hs = self.setup_test_homeserver(config=config) - - return hs + return self.setup_test_homeserver(config=config) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.user_id = self.register_user("bob", "abc123") self.access_token = self.login("bob", "abc123") - def test_get_sync_message(self): + def test_get_sync_message(self) -> None: """ When user consent server notices are enabled, a sync will cause a notice to fire (in a room which the user is invited to). The notice contains diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index dadc6efcbf..5b76383d76 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -24,6 +24,7 @@ from synapse.server import HomeServer from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -33,7 +34,7 @@ from tests.utils import default_config class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def default_config(self): + def default_config(self) -> JsonDict: config = default_config("test") config.update( @@ -86,18 +87,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] @override_config({"hs_disabled": True}) - def test_maybe_send_server_notice_disabled_hs(self): + def test_maybe_send_server_notice_disabled_hs(self) -> None: """If the HS is disabled, we should not send notices""" self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @override_config({"limit_usage_by_mau": False}) - def test_maybe_send_server_notice_to_user_flag_off(self): + def test_maybe_send_server_notice_to_user_flag_off(self) -> None: """If mau limiting is disabled, we should not send notices""" self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() - def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): + def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: """Test when user has blocked notice, but should have it removed""" self._rlsn._auth_blocking.check_auth_blocking = Mock( @@ -114,7 +115,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() self._send_notice.assert_called_once() - def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): + def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None: """ Test when user has blocked notice, but notice ought to be there (NOOP) """ @@ -134,7 +135,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() - def test_maybe_send_server_notice_to_user_add_blocked_notice(self): + def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None: """ Test when user does not have blocked notice, but should have one """ @@ -147,7 +148,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): # Would be better to check contents, but 2 calls == set blocking event self.assertEqual(self._send_notice.call_count, 2) - def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self): + def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None: """ Test when user does not have blocked notice, nor should they (NOOP) """ @@ -159,7 +160,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() - def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self): + def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None: """ Test when user is not part of the MAU cohort - this should not ever happen - but ... @@ -175,7 +176,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() @override_config({"mau_limit_alerting": False}) - def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): + def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked( + self, + ) -> None: """ Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room @@ -191,7 +194,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self.assertEqual(self._send_notice.call_count, 0) @override_config({"mau_limit_alerting": False}) - def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): + def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None: """ Test that when a server is disabled, that MAU limit alerting is ignored. """ @@ -207,7 +210,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self.assertEqual(self._send_notice.call_count, 2) @override_config({"mau_limit_alerting": False}) - def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): + def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked( + self, + ) -> None: """ When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. @@ -242,7 +247,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): sync.register_servlets, ] - def default_config(self): + def default_config(self) -> JsonDict: c = super().default_config() c["server_notices"] = { "system_mxid_localpart": "server", @@ -270,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" - def test_server_notice_only_sent_once(self): + def test_server_notice_only_sent_once(self) -> None: self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) self.store.user_last_seen_monthly_active = Mock( @@ -306,7 +311,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.assertEqual(count, 1) - def test_no_invite_without_notice(self): + def test_no_invite_without_notice(self) -> None: """Tests that a user doesn't get invited to a server notices room without a server notice being sent. @@ -328,7 +333,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): m.assert_called_once_with(user_id) - def test_invite_with_notice(self): + def test_invite_with_notice(self) -> None: """Tests that, if the MAU limit is hit, the server notices user invites each user to a room in which it has sent a notice. """ diff --git a/tests/test_federation.py b/tests/test_federation.py index 80e5c590d8..ddb43c8c98 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -12,53 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union from unittest.mock import Mock from twisted.internet.defer import succeed +from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import FederationError from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.federation.federation_base import event_from_pdu_json +from synapse.http.types import QueryParams from synapse.logging.context import LoggingContext -from synapse.types import UserID, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination from tests import unittest -from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): - def setUp(self): - + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock() - self.reactor = ThreadedMemoryReactorClock() - self.hs_clock = Clock(self.reactor) - self.homeserver = setup_test_homeserver( - self.addCleanup, - federation_http_client=self.http_client, - clock=self.hs_clock, - reactor=self.reactor, - ) + return self.setup_test_homeserver(federation_http_client=self.http_client) + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: user_id = UserID("us", "test") our_user = create_requester(user_id) - room_creator = self.homeserver.get_room_creation_handler() + room_creator = self.hs.get_room_creation_handler() self.room_id = self.get_success( room_creator.create_room( our_user, room_creator._presets_dict["public_chat"], ratelimit=False ) )[0]["room_id"] - self.store = self.homeserver.get_datastores().main + self.store = self.hs.get_datastores().main # Figure out what the most recent event is most_recent = self.get_success( - self.homeserver.get_datastores().main.get_latest_event_ids_in_room( - self.room_id - ) + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) )[0] join_event = make_event_from_dict( @@ -78,14 +73,16 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - self.handler = self.homeserver.get_federation_handler() - federation_event_handler = self.homeserver.get_federation_event_handler() + self.handler = self.hs.get_federation_handler() + federation_event_handler = self.hs.get_federation_event_handler() - async def _check_event_auth(origin, event, context): + async def _check_event_auth( + origin: Optional[str], event: EventBase, context: EventContext + ) -> None: pass federation_event_handler._check_event_auth = _check_event_auth - self.client = self.homeserver.get_federation_client() + self.client = self.hs.get_federation_client() self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( lambda dest, pdus, **k: succeed(pdus) ) @@ -104,16 +101,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase): "$join:test.serv", ) - def test_cant_hide_direct_ancestors(self): + def test_cant_hide_direct_ancestors(self) -> None: """ If you send a message, you must be able to provide the direct prev_events that said event references. """ - async def post_json(destination, path, data, headers=None, timeout=0): + async def post_json( + destination: str, + path: str, + data: Optional[JsonDict] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryParams] = None, + ) -> Union[JsonDict, list]: # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} + return {} self.http_client.post_json = post_json @@ -138,7 +144,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - federation_event_handler = self.homeserver.get_federation_event_handler() + federation_event_handler = self.hs.get_federation_event_handler() with LoggingContext("test-context"): failure = self.get_failure( federation_event_handler.on_receive_pdu("test.serv", lying_event), @@ -158,7 +164,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) self.assertEqual(extrem[0], "$join:test.serv") - def test_retry_device_list_resync(self): + def test_retry_device_list_resync(self) -> None: """Tests that device lists are marked as stale if they couldn't be synced, and that stale device lists are retried periodically. """ @@ -171,24 +177,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # When this function is called, increment the number of resync attempts (only if # we're querying devices for the right user ID), then raise a # NotRetryingDestination error to fail the resync gracefully. - def query_user_devices(destination, user_id): + def query_user_devices( + destination: str, user_id: str, timeout: int = 30000 + ) -> JsonDict: if user_id == remote_user_id: self.resync_attempts += 1 raise NotRetryingDestination(0, 0, destination) # Register the mock on the federation client. - federation_client = self.homeserver.get_federation_client() + federation_client = self.hs.get_federation_client() federation_client.query_user_devices = Mock(side_effect=query_user_devices) # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. - store = self.homeserver.get_datastores().main + store = self.hs.get_datastores().main store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. - device_list_updater = self.homeserver.get_device_handler().device_list_updater + device_list_updater = self.hs.get_device_handler().device_list_updater self.get_success( device_list_updater.incoming_device_list_update( origin=remote_origin, @@ -218,7 +226,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.reactor.advance(30) self.assertEqual(self.resync_attempts, 2) - def test_cross_signing_keys_retry(self): + def test_cross_signing_keys_retry(self) -> None: """Tests that resyncing a device list correctly processes cross-signing keys from the remote server. """ @@ -227,7 +235,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" # Register mock device list retrieval on the federation client. - federation_client = self.homeserver.get_federation_client() + federation_client = self.hs.get_federation_client() federation_client.query_user_devices = Mock( return_value=make_awaitable( { @@ -252,7 +260,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) # Resync the device list. - device_handler = self.homeserver.get_device_handler() + device_handler = self.hs.get_device_handler() self.get_success( device_handler.device_list_updater.user_device_resync(remote_user_id), ) @@ -279,7 +287,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): class StripUnsignedFromEventsTestCase(unittest.TestCase): - def test_strip_unauthorized_unsigned_values(self): + def test_strip_unauthorized_unsigned_values(self) -> None: event1 = { "sender": "@baduser:test.serv", "state_key": "@baduser:test.serv", @@ -296,7 +304,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase): # Make sure unauthorized fields are stripped from unsigned self.assertNotIn("more warez", filtered_event.unsigned) - def test_strip_event_maintains_allowed_fields(self): + def test_strip_event_maintains_allowed_fields(self) -> None: event2 = { "sender": "@baduser:test.serv", "state_key": "@baduser:test.serv", @@ -323,7 +331,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase): self.assertIn("invite_room_state", filtered_event2.unsigned) self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) - def test_strip_event_removes_fields_based_on_event_type(self): + def test_strip_event_removes_fields_based_on_event_type(self) -> None: event3 = { "sender": "@baduser:test.serv", "state_key": "@baduser:test.serv", diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index e62ebcc6a5..e5dae670a7 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -20,12 +20,13 @@ import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Awaitable, Callable, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar from unittest.mock import Mock import attr import zope.interface +from twisted.internet.interfaces import IProtocol from twisted.python.failure import Failure from twisted.web.client import ResponseDone from twisted.web.http import RESPONSES @@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse from synapse.types import JsonDict +if TYPE_CHECKING: + from sys import UnraisableHookArgs + TV = TypeVar("TV") @@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]: unraisable_exceptions = [] orig_unraisablehook = sys.unraisablehook - def unraisablehook(unraisable): + def unraisablehook(unraisable: "UnraisableHookArgs") -> None: unraisable_exceptions.append(unraisable.exc_value) - def cleanup(): + def cleanup() -> None: """ A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. """ sys.unraisablehook = orig_unraisablehook if unraisable_exceptions: - raise unraisable_exceptions.pop() + exc = unraisable_exceptions.pop() + assert exc is not None + raise exc sys.unraisablehook = unraisablehook return cleanup -def simple_async_mock(return_value=None, raises=None) -> Mock: +def simple_async_mock( + return_value: Optional[TV] = None, raises: Optional[Exception] = None +) -> Mock: # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args, **kwargs): + async def cb(*args: Any, **kwargs: Any) -> Optional[TV]: if raises: raise raises return return_value @@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc] headers: Headers = attr.Factory(Headers) @property - def phrase(self): + def phrase(self) -> bytes: return RESPONSES.get(self.code, b"Unknown Status") @property - def length(self): + def length(self) -> int: return len(self.body) - def deliverBody(self, protocol): + def deliverBody(self, protocol: IProtocol) -> None: protocol.dataReceived(self.body) protocol.connectionLost(Failure(ResponseDone())) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8027c7a856..1a50c2acf1 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes @@ -32,7 +32,7 @@ async def inject_member_event( membership: str, target: Optional[str] = None, extra_content: Optional[dict] = None, - **kwargs, + **kwargs: Any, ) -> EventBase: """Inject a membership event into a room.""" if target is None: @@ -57,7 +57,7 @@ async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> EventBase: """Inject a generic event into a room @@ -82,7 +82,7 @@ async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> Tuple[EventBase, EventContext]: if room_version is None: room_version = await hs.get_datastores().main.get_room_version_id( diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py index e878af5f12..189c697efb 100644 --- a/tests/test_utils/html_parsers.py +++ b/tests/test_utils/html_parsers.py @@ -13,13 +13,13 @@ # limitations under the License. from html.parser import HTMLParser -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, NoReturn, Optional, Tuple class TestHtmlParser(HTMLParser): """A generic HTML page parser which extracts useful things from the HTML""" - def __init__(self): + def __init__(self) -> None: super().__init__() # a list of links found in the doc @@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser): assert input_name self.hiddens[input_name] = attr_dict["value"] - def error(_, message): + def error(self, message: str) -> NoReturn: raise AssertionError(message) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index 304c7b98c5..b522163a34 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler): tx_log = twisted.logger.Logger() - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: log_entry = self.format(record) log_level = record.levelname.lower().replace("warning", "warn") self.tx_log.emit( @@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler): ) -def setup_logging(): +def setup_logging() -> None: """Configure the python logging appropriately for the tests. (Logs will end up in _trial_temp.) diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py index 1461d23ee8..d555b24255 100644 --- a/tests/test_utils/oidc.py +++ b/tests/test_utils/oidc.py @@ -14,7 +14,7 @@ import json -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, ContextManager, Dict, List, Optional, Tuple from unittest.mock import Mock, patch from urllib.parse import parse_qs @@ -77,14 +77,14 @@ class FakeOidcServer: self._id_token_overrides: Dict[str, Any] = {} - def reset_mocks(self): + def reset_mocks(self) -> None: self.request.reset_mock() self.get_jwks_handler.reset_mock() self.get_metadata_handler.reset_mock() self.get_userinfo_handler.reset_mock() self.post_token_handler.reset_mock() - def patch_homeserver(self, hs: HomeServer): + def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]: """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``. This patch should be used whenever the HS is expected to perform request to the @@ -188,7 +188,7 @@ class FakeOidcServer: return self._sign(logout_token) - def id_token_override(self, overrides: dict): + def id_token_override(self, overrides: dict) -> ContextManager[dict]: """Temporarily patch the ID token generated by the token endpoint.""" return patch.object(self, "_id_token_overrides", overrides) @@ -247,7 +247,7 @@ class FakeOidcServer: metadata: bool = False, token: bool = False, userinfo: bool = False, - ): + ) -> ContextManager[Dict[str, Mock]]: """A context which makes a set of endpoints return a 500 error. Args: diff --git a/tests/test_visibility.py b/tests/test_visibility.py index d0b9ad5454..875e37988f 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -258,7 +258,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase): - def test_out_of_band_invite_rejection(self): + def test_out_of_band_invite_rejection(self) -> None: # this is where we have received an invite event over federation, and then # rejected it. invite_pdu = { diff --git a/tests/unittest.py b/tests/unittest.py index fa92dd94eb..68e59a88dc 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -315,7 +315,7 @@ class HomeserverTestCase(TestCase): # This has to be a function and not just a Mock, because # `self.helper.auth_user_id` is temporarily reassigned in some tests - async def get_requester(*args, **kwargs) -> Requester: + async def get_requester(*args: Any, **kwargs: Any) -> Requester: assert self.helper.auth_user_id is not None return create_requester( user_id=UserID.from_string(self.helper.auth_user_id), -- cgit 1.4.1