summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12347.misc1
-rw-r--r--mypy.ini1
-rw-r--r--tests/handlers/test_e2e_keys.py6
-rw-r--r--tests/handlers/test_federation.py5
-rw-r--r--tests/handlers/test_oidc.py7
-rw-r--r--tests/handlers/test_user_directory.py2
-rw-r--r--tests/rest/admin/test_media.py8
-rw-r--r--tests/rest/admin/test_user.py15
-rw-r--r--tests/server.py6
-rw-r--r--tests/storage/databases/main/test_lock.py8
-rw-r--r--tests/storage/test_appservice.py1
-rw-r--r--tests/unittest.py85
12 files changed, 97 insertions, 48 deletions
diff --git a/changelog.d/12347.misc b/changelog.d/12347.misc
new file mode 100644
index 0000000000..1f6f584e6d
--- /dev/null
+++ b/changelog.d/12347.misc
@@ -0,0 +1 @@
+Add type annotations for `tests/unittest.py`.
diff --git a/mypy.ini b/mypy.ini
index 84e6b8646e..85291099ac 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -83,7 +83,6 @@ exclude = (?x)
    |tests/test_server.py
    |tests/test_state.py
    |tests/test_terms_auth.py
-   |tests/unittest.py
    |tests/util/caches/test_cached_call.py
    |tests/util/caches/test_deferred_cache.py
    |tests/util/caches/test_descriptors.py
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index ac21a28c43..8c74ed1fcf 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -463,8 +463,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         res = e.value.code
         self.assertEqual(res, 400)
 
-        res = self.get_success(self.handler.query_local_devices({local_user: None}))
-        self.assertDictEqual(res, {local_user: {}})
+        query_res = self.get_success(
+            self.handler.query_local_devices({local_user: None})
+        )
+        self.assertDictEqual(query_res, {local_user: {}})
 
     def test_upload_signatures(self) -> None:
         """should check signatures that are uploaded"""
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 0fa5045301..060ba5f517 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -375,7 +375,8 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
         member_event.signatures = member_event_dict["signatures"]
 
         # Add the new member_event to the StateMap
-        prev_state_map[
+        updated_state_map = dict(prev_state_map)
+        updated_state_map[
             (member_event.type, member_event.state_key)
         ] = member_event.event_id
         auth_events.append(member_event)
@@ -399,7 +400,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
                 prev_event_ids=message_event_dict["prev_events"],
                 auth_event_ids=self._event_auth_handler.compute_auth_events(
                     builder,
-                    prev_state_map,
+                    updated_state_map,
                     for_verification=False,
                 ),
                 depth=message_event_dict["depth"],
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 014815db6e..9684120c70 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -354,10 +354,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
         req = Mock(spec=["cookies"])
         req.cookies = []
 
-        url = self.get_success(
-            self.provider.handle_redirect_request(req, b"http://client/redirect")
+        url = urlparse(
+            self.get_success(
+                self.provider.handle_redirect_request(req, b"http://client/redirect")
+            )
         )
-        url = urlparse(url)
         auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
 
         self.assertEqual(url.scheme, auth_endpoint.scheme)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 92012cd6f7..c6e501c7be 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -351,6 +351,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.handler.handle_local_profile_change(regular_user_id, profile_info)
         )
         profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
+        assert profile is not None
         self.assertTrue(profile["display_name"] == display_name)
 
     def test_handle_local_profile_change_with_deactivated_user(self) -> None:
@@ -369,6 +370,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # profile is in directory
         profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        assert profile is not None
         self.assertTrue(profile["display_name"] == display_name)
 
         # deactivate user
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 0d47dd0aff..e909e444ac 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -702,6 +702,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
         """
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["quarantined_by"])
 
         # quarantining
@@ -715,6 +716,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(channel.json_body)
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertTrue(media_info["quarantined_by"])
 
         # remove from quarantine
@@ -728,6 +730,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(channel.json_body)
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["quarantined_by"])
 
     def test_quarantine_protected_media(self) -> None:
@@ -740,6 +743,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
 
         # verify protection
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertTrue(media_info["safe_from_quarantine"])
 
         # quarantining
@@ -754,6 +758,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
 
         # verify that is not in quarantine
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["quarantined_by"])
 
 
@@ -830,6 +835,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
         """
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["safe_from_quarantine"])
 
         # protect
@@ -843,6 +849,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(channel.json_body)
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertTrue(media_info["safe_from_quarantine"])
 
         # unprotect
@@ -856,6 +863,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertFalse(channel.json_body)
 
         media_info = self.get_success(self.store.get_local_media(self.media_id))
+        assert media_info is not None
         self.assertFalse(media_info["safe_from_quarantine"])
 
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index bef911d5df..0cdf1dec40 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1590,10 +1590,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
 
-        pushers = self.get_success(
-            self.store.get_pushers_by({"user_name": "@bob:test"})
+        pushers = list(
+            self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertEqual("@bob:test", pushers[0].user_name)
 
@@ -1632,10 +1631,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
 
-        pushers = self.get_success(
-            self.store.get_pushers_by({"user_name": "@bob:test"})
+        pushers = list(
+            self.get_success(self.store.get_pushers_by({"user_name": "@bob:test"}))
         )
-        pushers = list(pushers)
         self.assertEqual(len(pushers), 0)
 
     def test_set_password(self) -> None:
@@ -2144,6 +2142,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # is in user directory
         profile = self.get_success(self.store.get_user_in_directory(self.other_user))
+        assert profile is not None
         self.assertTrue(profile["display_name"] == "User")
 
         # Deactivate user
@@ -2711,6 +2710,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         user_tuple = self.get_success(
             self.store.get_user_by_access_token(other_user_token)
         )
+        assert user_tuple is not None
         token_id = user_tuple.token_id
 
         self.get_success(
@@ -3676,6 +3676,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
         # The user starts off as not shadow-banned.
         other_user_token = self.login("user", "pass")
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        assert result is not None
         self.assertFalse(result.shadow_banned)
 
         channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
@@ -3684,6 +3685,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
 
         # Ensure the user is shadow-banned (and the cache was cleared).
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        assert result is not None
         self.assertTrue(result.shadow_banned)
 
         # Un-shadow-ban the user.
@@ -3695,6 +3697,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
 
         # Ensure the user is no longer shadow-banned (and the cache was cleared).
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        assert result is not None
         self.assertFalse(result.shadow_banned)
 
 
diff --git a/tests/server.py b/tests/server.py
index 6ce2a17bf4..aaa5ca3e74 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -22,7 +22,6 @@ import warnings
 from collections import deque
 from io import SEEK_END, BytesIO
 from typing import (
-    AnyStr,
     Callable,
     Dict,
     Iterable,
@@ -86,6 +85,9 @@ from tests.utils import (
 
 logger = logging.getLogger(__name__)
 
+# the type of thing that can be passed into `make_request` in the headers list
+CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
+
 
 class TimedOutException(Exception):
     """
@@ -260,7 +262,7 @@ def make_request(
     federation_auth_origin: Optional[bytes] = None,
     content_is_form: bool = False,
     await_result: bool = True,
-    custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+    custom_headers: Optional[Iterable[CustomHeaderType]] = None,
     client_ip: str = "127.0.0.1",
 ) -> FakeChannel:
     """
diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py
index 3ac4646969..74c6224eb6 100644
--- a/tests/storage/databases/main/test_lock.py
+++ b/tests/storage/databases/main/test_lock.py
@@ -28,7 +28,7 @@ class LockTestCase(unittest.HomeserverTestCase):
         """
         # First to acquire this lock, so it should complete
         lock = self.get_success(self.store.try_acquire_lock("name", "key"))
-        self.assertIsNotNone(lock)
+        assert lock is not None
 
         # Enter the context manager
         self.get_success(lock.__aenter__())
@@ -45,7 +45,7 @@ class LockTestCase(unittest.HomeserverTestCase):
 
         # We can now acquire the lock again.
         lock3 = self.get_success(self.store.try_acquire_lock("name", "key"))
-        self.assertIsNotNone(lock3)
+        assert lock3 is not None
         self.get_success(lock3.__aenter__())
         self.get_success(lock3.__aexit__(None, None, None))
 
@@ -53,7 +53,7 @@ class LockTestCase(unittest.HomeserverTestCase):
         """Test that we don't time out locks while they're still active"""
 
         lock = self.get_success(self.store.try_acquire_lock("name", "key"))
-        self.assertIsNotNone(lock)
+        assert lock is not None
 
         self.get_success(lock.__aenter__())
 
@@ -69,7 +69,7 @@ class LockTestCase(unittest.HomeserverTestCase):
         """Test that we time out locks if they're not updated for ages"""
 
         lock = self.get_success(self.store.try_acquire_lock("name", "key"))
-        self.assertIsNotNone(lock)
+        assert lock is not None
 
         self.get_success(lock.__aenter__())
 
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 08078d38e2..1bf93e79a7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -358,6 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         self.get_success(self._insert_txn(service.id, 12, other_events))
 
         txn = self.get_success(self.store.get_oldest_unsent_txn(service))
+        assert txn is not None
         self.assertEqual(service, txn.service)
         self.assertEqual(10, txn.id)
         self.assertEqual(events, txn.events)
diff --git a/tests/unittest.py b/tests/unittest.py
index 5b19065c71..9afa68c164 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -22,10 +22,11 @@ import secrets
 import time
 from typing import (
     Any,
-    AnyStr,
+    Awaitable,
     Callable,
     ClassVar,
     Dict,
+    Generic,
     Iterable,
     List,
     Optional,
@@ -39,6 +40,7 @@ from unittest.mock import Mock, patch
 import canonicaljson
 import signedjson.key
 import unpaddedbase64
+from typing_extensions import Protocol
 
 from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.python.failure import Failure
@@ -49,7 +51,7 @@ from twisted.web.resource import Resource
 from twisted.web.server import Request
 
 from synapse import events
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
@@ -70,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 
-from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
+from tests.server import (
+    CustomHeaderType,
+    FakeChannel,
+    get_clock,
+    make_request,
+    setup_test_homeserver,
+)
 from tests.test_utils import event_injection, setup_awaitable_errors
 from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
@@ -78,6 +86,17 @@ from tests.utils import default_config, setupdb
 setupdb()
 setup_logging()
 
+TV = TypeVar("TV")
+_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
+
+
+class _TypedFailure(Generic[_ExcType], Protocol):
+    """Extension to twisted.Failure, where the 'value' has a certain type."""
+
+    @property
+    def value(self) -> _ExcType:
+        ...
+
 
 def around(target):
     """A CLOS-style 'around' modifier, which wraps the original method of the
@@ -276,6 +295,7 @@ class HomeserverTestCase(TestCase):
 
         if hasattr(self, "user_id"):
             if self.hijack_auth:
+                assert self.helper.auth_user_id is not None
 
                 # We need a valid token ID to satisfy foreign key constraints.
                 token_id = self.get_success(
@@ -288,6 +308,7 @@ class HomeserverTestCase(TestCase):
                 )
 
                 async def get_user_by_access_token(token=None, allow_guest=False):
+                    assert self.helper.auth_user_id is not None
                     return {
                         "user": UserID.from_string(self.helper.auth_user_id),
                         "token_id": token_id,
@@ -295,6 +316,7 @@ class HomeserverTestCase(TestCase):
                     }
 
                 async def get_user_by_req(request, allow_guest=False, rights="access"):
+                    assert self.helper.auth_user_id is not None
                     return create_requester(
                         UserID.from_string(self.helper.auth_user_id),
                         token_id,
@@ -311,7 +333,7 @@ class HomeserverTestCase(TestCase):
                 )
 
         if self.needs_threadpool:
-            self.reactor.threadpool = ThreadPool()
+            self.reactor.threadpool = ThreadPool()  # type: ignore[assignment]
             self.addCleanup(self.reactor.threadpool.stop)
             self.reactor.threadpool.start()
 
@@ -426,7 +448,7 @@ class HomeserverTestCase(TestCase):
         federation_auth_origin: Optional[bytes] = None,
         content_is_form: bool = False,
         await_result: bool = True,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """
@@ -511,30 +533,36 @@ class HomeserverTestCase(TestCase):
 
         return hs
 
-    def pump(self, by=0.0):
+    def pump(self, by: float = 0.0) -> None:
         """
         Pump the reactor enough that Deferreds will fire.
         """
         self.reactor.pump([by] * 100)
 
-    def get_success(self, d, by=0.0):
-        deferred: Deferred[TV] = ensureDeferred(d)
+    def get_success(
+        self,
+        d: Awaitable[TV],
+        by: float = 0.0,
+    ) -> TV:
+        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
         self.pump(by=by)
         return self.successResultOf(deferred)
 
-    def get_failure(self, d, exc):
+    def get_failure(
+        self, d: Awaitable[Any], exc: Type[_ExcType]
+    ) -> _TypedFailure[_ExcType]:
         """
         Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
         """
-        deferred: Deferred[Any] = ensureDeferred(d)
+        deferred: Deferred[Any] = ensureDeferred(d)  # type: ignore[arg-type]
         self.pump()
         return self.failureResultOf(deferred, exc)
 
-    def get_success_or_raise(self, d, by=0.0):
+    def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
         """Drive deferred to completion and return result or raise exception
         on failure.
         """
-        deferred: Deferred[TV] = ensureDeferred(d)
+        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
 
         results: list = []
         deferred.addBoth(results.append)
@@ -642,11 +670,11 @@ class HomeserverTestCase(TestCase):
 
     def login(
         self,
-        username,
-        password,
-        device_id=None,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-    ):
+        username: str,
+        password: str,
+        device_id: Optional[str] = None,
+        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
+    ) -> str:
         """
         Log in a user, and get an access token. Requires the Login API be
         registered.
@@ -668,18 +696,22 @@ class HomeserverTestCase(TestCase):
         return access_token
 
     def create_and_send_event(
-        self, room_id, user, soft_failed=False, prev_event_ids=None
-    ):
+        self,
+        room_id: str,
+        user: UserID,
+        soft_failed: bool = False,
+        prev_event_ids: Optional[List[str]] = None,
+    ) -> str:
         """
         Create and send an event.
 
         Args:
-            soft_failed (bool): Whether to create a soft failed event or not
-            prev_event_ids (list[str]|None): Explicitly set the prev events,
+            soft_failed: Whether to create a soft failed event or not
+            prev_event_ids: Explicitly set the prev events,
                 or if None just use the default
 
         Returns:
-            str: The new event's ID.
+            The new event's ID.
         """
         event_creator = self.hs.get_event_creation_handler()
         requester = create_requester(user)
@@ -706,7 +738,7 @@ class HomeserverTestCase(TestCase):
 
         return event.event_id
 
-    def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+    def inject_room_member(self, room: str, user: str, membership: str) -> None:
         """
         Inject a membership event into a room.
 
@@ -766,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
         path: str,
         content: Optional[JsonDict] = None,
         await_result: bool = True,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """Make an inbound signed federation request to this server
@@ -799,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
             self.site,
             method=method,
             path=path,
-            content=content,
+            content=content or "",
             shorthand=False,
             await_result=await_result,
             custom_headers=custom_headers,
@@ -878,9 +910,6 @@ def override_config(extra_config):
     return decorator
 
 
-TV = TypeVar("TV")
-
-
 def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
     """A test decorator which will skip the decorated test unless a condition is set