diff options
-rw-r--r-- | changelog.d/12347.misc | 1 | ||||
-rw-r--r-- | mypy.ini | 1 | ||||
-rw-r--r-- | tests/handlers/test_e2e_keys.py | 6 | ||||
-rw-r--r-- | tests/handlers/test_federation.py | 5 | ||||
-rw-r--r-- | tests/handlers/test_oidc.py | 7 | ||||
-rw-r--r-- | tests/handlers/test_user_directory.py | 2 | ||||
-rw-r--r-- | tests/rest/admin/test_media.py | 8 | ||||
-rw-r--r-- | tests/rest/admin/test_user.py | 15 | ||||
-rw-r--r-- | tests/server.py | 6 | ||||
-rw-r--r-- | tests/storage/databases/main/test_lock.py | 8 | ||||
-rw-r--r-- | tests/storage/test_appservice.py | 1 | ||||
-rw-r--r-- | tests/unittest.py | 85 |
12 files changed, 97 insertions, 48 deletions
diff --git a/changelog.d/12347.misc b/changelog.d/12347.misc new file mode 100644 index 0000000000..1f6f584e6d --- /dev/null +++ b/changelog.d/12347.misc @@ -0,0 +1 @@ +Add type annotations for `tests/unittest.py`. diff --git a/mypy.ini b/mypy.ini index 84e6b8646e..85291099ac 100644 --- a/mypy.ini +++ b/mypy.ini @@ -83,7 +83,6 @@ exclude = (?x) |tests/test_server.py |tests/test_state.py |tests/test_terms_auth.py - |tests/unittest.py |tests/util/caches/test_cached_call.py |tests/util/caches/test_deferred_cache.py |tests/util/caches/test_descriptors.py diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index ac21a28c43..8c74ed1fcf 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -463,8 +463,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res = e.value.code self.assertEqual(res, 400) - res = self.get_success(self.handler.query_local_devices({local_user: None})) - self.assertDictEqual(res, {local_user: {}}) + query_res = self.get_success( + self.handler.query_local_devices({local_user: None}) + ) + self.assertDictEqual(query_res, {local_user: {}}) def test_upload_signatures(self) -> None: """should check signatures that are uploaded""" diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 0fa5045301..060ba5f517 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -375,7 +375,8 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): member_event.signatures = member_event_dict["signatures"] # Add the new member_event to the StateMap - prev_state_map[ + updated_state_map = dict(prev_state_map) + updated_state_map[ (member_event.type, member_event.state_key) ] = member_event.event_id auth_events.append(member_event) @@ -399,7 +400,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): prev_event_ids=message_event_dict["prev_events"], auth_event_ids=self._event_auth_handler.compute_auth_events( builder, - prev_state_map, + updated_state_map, for_verification=False, ), depth=message_event_dict["depth"], diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 014815db6e..9684120c70 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -354,10 +354,11 @@ class OidcHandlerTestCase(HomeserverTestCase): req = Mock(spec=["cookies"]) req.cookies = [] - url = self.get_success( - self.provider.handle_redirect_request(req, b"http://client/redirect") + url = urlparse( + self.get_success( + self.provider.handle_redirect_request(req, b"http://client/redirect") + ) ) - url = urlparse(url) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) self.assertEqual(url.scheme, auth_endpoint.scheme) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 92012cd6f7..c6e501c7be 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -351,6 +351,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.handler.handle_local_profile_change(regular_user_id, profile_info) ) profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) + assert profile is not None self.assertTrue(profile["display_name"] == display_name) def test_handle_local_profile_change_with_deactivated_user(self) -> None: @@ -369,6 +370,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is in directory profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + assert profile is not None self.assertTrue(profile["display_name"] == display_name) # deactivate user diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 0d47dd0aff..e909e444ac 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -702,6 +702,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): """ media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["quarantined_by"]) # quarantining @@ -715,6 +716,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertTrue(media_info["quarantined_by"]) # remove from quarantine @@ -728,6 +730,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["quarantined_by"]) def test_quarantine_protected_media(self) -> None: @@ -740,6 +743,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): # verify protection media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertTrue(media_info["safe_from_quarantine"]) # quarantining @@ -754,6 +758,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): # verify that is not in quarantine media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["quarantined_by"]) @@ -830,6 +835,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): """ media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["safe_from_quarantine"]) # protect @@ -843,6 +849,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertTrue(media_info["safe_from_quarantine"]) # unprotect @@ -856,6 +863,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(channel.json_body) media_info = self.get_success(self.store.get_local_media(self.media_id)) + assert media_info is not None self.assertFalse(media_info["safe_from_quarantine"]) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index bef911d5df..0cdf1dec40 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1590,10 +1590,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - pushers = self.get_success( - self.store.get_pushers_by({"user_name": "@bob:test"}) + pushers = list( + self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"})) ) - pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual("@bob:test", pushers[0].user_name) @@ -1632,10 +1631,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - pushers = self.get_success( - self.store.get_pushers_by({"user_name": "@bob:test"}) + pushers = list( + self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"})) ) - pushers = list(pushers) self.assertEqual(len(pushers), 0) def test_set_password(self) -> None: @@ -2144,6 +2142,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # is in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + assert profile is not None self.assertTrue(profile["display_name"] == "User") # Deactivate user @@ -2711,6 +2710,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): user_tuple = self.get_success( self.store.get_user_by_access_token(other_user_token) ) + assert user_tuple is not None token_id = user_tuple.token_id self.get_success( @@ -3676,6 +3676,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # The user starts off as not shadow-banned. other_user_token = self.login("user", "pass") result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + assert result is not None self.assertFalse(result.shadow_banned) channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) @@ -3684,6 +3685,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # Ensure the user is shadow-banned (and the cache was cleared). result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + assert result is not None self.assertTrue(result.shadow_banned) # Un-shadow-ban the user. @@ -3695,6 +3697,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # Ensure the user is no longer shadow-banned (and the cache was cleared). result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + assert result is not None self.assertFalse(result.shadow_banned) diff --git a/tests/server.py b/tests/server.py index 6ce2a17bf4..aaa5ca3e74 100644 --- a/tests/server.py +++ b/tests/server.py @@ -22,7 +22,6 @@ import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( - AnyStr, Callable, Dict, Iterable, @@ -86,6 +85,9 @@ from tests.utils import ( logger = logging.getLogger(__name__) +# the type of thing that can be passed into `make_request` in the headers list +CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] + class TimedOutException(Exception): """ @@ -260,7 +262,7 @@ def make_request( federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + custom_headers: Optional[Iterable[CustomHeaderType]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 3ac4646969..74c6224eb6 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase): """ # First to acquire this lock, so it should complete lock = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock) + assert lock is not None # Enter the context manager self.get_success(lock.__aenter__()) @@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase): # We can now acquire the lock again. lock3 = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock3) + assert lock3 is not None self.get_success(lock3.__aenter__()) self.get_success(lock3.__aexit__(None, None, None)) @@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase): """Test that we don't time out locks while they're still active""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock) + assert lock is not None self.get_success(lock.__aenter__()) @@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase): """Test that we time out locks if they're not updated for ages""" lock = self.get_success(self.store.try_acquire_lock("name", "key")) - self.assertIsNotNone(lock) + assert lock is not None self.get_success(lock.__aenter__()) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 08078d38e2..1bf93e79a7 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -358,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(service.id, 12, other_events)) txn = self.get_success(self.store.get_oldest_unsent_txn(service)) + assert txn is not None self.assertEqual(service, txn.service) self.assertEqual(10, txn.id) self.assertEqual(events, txn.events) diff --git a/tests/unittest.py b/tests/unittest.py index 5b19065c71..9afa68c164 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -22,10 +22,11 @@ import secrets import time from typing import ( Any, - AnyStr, + Awaitable, Callable, ClassVar, Dict, + Generic, Iterable, List, Optional, @@ -39,6 +40,7 @@ from unittest.mock import Mock, patch import canonicaljson import signedjson.key import unpaddedbase64 +from typing_extensions import Protocol from twisted.internet.defer import Deferred, ensureDeferred from twisted.python.failure import Failure @@ -49,7 +51,7 @@ from twisted.web.resource import Resource from twisted.web.server import Request from synapse import events -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION @@ -70,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree -from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver +from tests.server import ( + CustomHeaderType, + FakeChannel, + get_clock, + make_request, + setup_test_homeserver, +) from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -78,6 +86,17 @@ from tests.utils import default_config, setupdb setupdb() setup_logging() +TV = TypeVar("TV") +_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True) + + +class _TypedFailure(Generic[_ExcType], Protocol): + """Extension to twisted.Failure, where the 'value' has a certain type.""" + + @property + def value(self) -> _ExcType: + ... + def around(target): """A CLOS-style 'around' modifier, which wraps the original method of the @@ -276,6 +295,7 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: + assert self.helper.auth_user_id is not None # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( @@ -288,6 +308,7 @@ class HomeserverTestCase(TestCase): ) async def get_user_by_access_token(token=None, allow_guest=False): + assert self.helper.auth_user_id is not None return { "user": UserID.from_string(self.helper.auth_user_id), "token_id": token_id, @@ -295,6 +316,7 @@ class HomeserverTestCase(TestCase): } async def get_user_by_req(request, allow_guest=False, rights="access"): + assert self.helper.auth_user_id is not None return create_requester( UserID.from_string(self.helper.auth_user_id), token_id, @@ -311,7 +333,7 @@ class HomeserverTestCase(TestCase): ) if self.needs_threadpool: - self.reactor.threadpool = ThreadPool() + self.reactor.threadpool = ThreadPool() # type: ignore[assignment] self.addCleanup(self.reactor.threadpool.stop) self.reactor.threadpool.start() @@ -426,7 +448,7 @@ class HomeserverTestCase(TestCase): federation_auth_origin: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, - custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + custom_headers: Optional[Iterable[CustomHeaderType]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """ @@ -511,30 +533,36 @@ class HomeserverTestCase(TestCase): return hs - def pump(self, by=0.0): + def pump(self, by: float = 0.0) -> None: """ Pump the reactor enough that Deferreds will fire. """ self.reactor.pump([by] * 100) - def get_success(self, d, by=0.0): - deferred: Deferred[TV] = ensureDeferred(d) + def get_success( + self, + d: Awaitable[TV], + by: float = 0.0, + ) -> TV: + deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type] self.pump(by=by) return self.successResultOf(deferred) - def get_failure(self, d, exc): + def get_failure( + self, d: Awaitable[Any], exc: Type[_ExcType] + ) -> _TypedFailure[_ExcType]: """ Run a Deferred and get a Failure from it. The failure must be of the type `exc`. """ - deferred: Deferred[Any] = ensureDeferred(d) + deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type] self.pump() return self.failureResultOf(deferred, exc) - def get_success_or_raise(self, d, by=0.0): + def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV: """Drive deferred to completion and return result or raise exception on failure. """ - deferred: Deferred[TV] = ensureDeferred(d) + deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type] results: list = [] deferred.addBoth(results.append) @@ -642,11 +670,11 @@ class HomeserverTestCase(TestCase): def login( self, - username, - password, - device_id=None, - custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, - ): + username: str, + password: str, + device_id: Optional[str] = None, + custom_headers: Optional[Iterable[CustomHeaderType]] = None, + ) -> str: """ Log in a user, and get an access token. Requires the Login API be registered. @@ -668,18 +696,22 @@ class HomeserverTestCase(TestCase): return access_token def create_and_send_event( - self, room_id, user, soft_failed=False, prev_event_ids=None - ): + self, + room_id: str, + user: UserID, + soft_failed: bool = False, + prev_event_ids: Optional[List[str]] = None, + ) -> str: """ Create and send an event. Args: - soft_failed (bool): Whether to create a soft failed event or not - prev_event_ids (list[str]|None): Explicitly set the prev events, + soft_failed: Whether to create a soft failed event or not + prev_event_ids: Explicitly set the prev events, or if None just use the default Returns: - str: The new event's ID. + The new event's ID. """ event_creator = self.hs.get_event_creation_handler() requester = create_requester(user) @@ -706,7 +738,7 @@ class HomeserverTestCase(TestCase): return event.event_id - def inject_room_member(self, room: str, user: str, membership: Membership) -> None: + def inject_room_member(self, room: str, user: str, membership: str) -> None: """ Inject a membership event into a room. @@ -766,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): path: str, content: Optional[JsonDict] = None, await_result: bool = True, - custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + custom_headers: Optional[Iterable[CustomHeaderType]] = None, client_ip: str = "127.0.0.1", ) -> FakeChannel: """Make an inbound signed federation request to this server @@ -799,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): self.site, method=method, path=path, - content=content, + content=content or "", shorthand=False, await_result=await_result, custom_headers=custom_headers, @@ -878,9 +910,6 @@ def override_config(extra_config): return decorator -TV = TypeVar("TV") - - def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]: """A test decorator which will skip the decorated test unless a condition is set |