summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2021-11-16 10:41:35 +0000
committerGitHub <noreply@github.com>2021-11-16 10:41:35 +0000
commit0dda1a79687b8375dd5b23763ba1585e5dad030d (patch)
treedd60cd7bd9585e775c1b65b23de8997db1298a39
parentchange 'Home Server' to one word 'homeserver' (#11320) (diff)
downloadsynapse-0dda1a79687b8375dd5b23763ba1585e5dad030d.tar.xz
Misc typing fixes for tests, part 2 of N (#11330)
Diffstat (limited to '')
-rw-r--r--changelog.d/11330.misc1
-rw-r--r--tests/handlers/test_register.py9
-rw-r--r--tests/rest/client/utils.py51
-rw-r--r--tests/server.py3
-rw-r--r--tests/unittest.py31
5 files changed, 66 insertions, 29 deletions
diff --git a/changelog.d/11330.misc b/changelog.d/11330.misc
new file mode 100644
index 0000000000..86f26543dd
--- /dev/null
+++ b/changelog.d/11330.misc
@@ -0,0 +1 @@
+Improve type annotations in Synapse's test suite.
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index db691c4c1c..cd6f2c77ae 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_not_blocked(self):
-        self.store.count_monthly_users = Mock(
+        # Type ignore: mypy doesn't like us assigning to methods.
+        self.store.count_monthly_users = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
         )
         # Ensure does not throw exception
@@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_blocked(self):
-        self.store.get_monthly_active_count = Mock(
+        # Type ignore: mypy doesn't like us assigning to methods.
+        self.store.get_monthly_active_count = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(self.lots_of_users)
         )
         self.get_failure(
@@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
 
-        self.store.get_monthly_active_count = Mock(
+        # Type ignore: mypy doesn't like us assigning to methods.
+        self.store.get_monthly_active_count = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(self.hs.config.server.max_mau_value)
         )
         self.get_failure(
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 7cf782e2d6..1af5e5cee5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -28,11 +28,12 @@ from typing import (
     MutableMapping,
     Optional,
     Tuple,
-    Union,
+    overload,
 )
 from unittest.mock import patch
 
 import attr
+from typing_extensions import Literal
 
 from twisted.web.resource import Resource
 from twisted.web.server import Site
@@ -55,6 +56,32 @@ class RestHelper:
     site = attr.ib(type=Site)
     auth_user_id = attr.ib()
 
+    @overload
+    def create_room_as(
+        self,
+        room_creator: Optional[str] = ...,
+        is_public: Optional[bool] = ...,
+        room_version: Optional[str] = ...,
+        tok: Optional[str] = ...,
+        expect_code: Literal[200] = ...,
+        extra_content: Optional[Dict] = ...,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+    ) -> str:
+        ...
+
+    @overload
+    def create_room_as(
+        self,
+        room_creator: Optional[str] = ...,
+        is_public: Optional[bool] = ...,
+        room_version: Optional[str] = ...,
+        tok: Optional[str] = ...,
+        expect_code: int = ...,
+        extra_content: Optional[Dict] = ...,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+    ) -> Optional[str]:
+        ...
+
     def create_room_as(
         self,
         room_creator: Optional[str] = None,
@@ -64,7 +91,7 @@ class RestHelper:
         expect_code: int = 200,
         extra_content: Optional[Dict] = None,
         custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-    ) -> str:
+    ) -> Optional[str]:
         """
         Create a room.
 
@@ -107,6 +134,8 @@ class RestHelper:
 
         if expect_code == 200:
             return channel.json_body["room_id"]
+        else:
+            return None
 
     def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
         self.change_membership(
@@ -176,7 +205,7 @@ class RestHelper:
         extra_data: Optional[dict] = None,
         tok: Optional[str] = None,
         expect_code: int = 200,
-        expect_errcode: str = None,
+        expect_errcode: Optional[str] = None,
     ) -> None:
         """
         Send a membership state event into a room.
@@ -260,9 +289,7 @@ class RestHelper:
         txn_id=None,
         tok=None,
         expect_code=200,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         if txn_id is None:
             txn_id = "m%s" % (str(time.time()))
@@ -509,7 +536,7 @@ class RestHelper:
             went.
         """
 
-        cookies = {}
+        cookies: Dict[str, str] = {}
 
         # if we're doing a ui auth, hit the ui auth redirect endpoint
         if ui_auth_session_id:
@@ -631,7 +658,13 @@ class RestHelper:
 
         # hit the redirect url again with the right Host header, which should now issue
         # a cookie and redirect to the SSO provider.
-        location = channel.headers.getRawHeaders("Location")[0]
+        def get_location(channel: FakeChannel) -> str:
+            location_values = channel.headers.getRawHeaders("Location")
+            # Keep mypy happy by asserting that location_values is nonempty
+            assert location_values
+            return location_values[0]
+
+        location = get_location(channel)
         parts = urllib.parse.urlsplit(location)
         channel = make_request(
             self.hs.get_reactor(),
@@ -645,7 +678,7 @@ class RestHelper:
 
         assert channel.code == 302
         channel.extract_cookies(cookies)
-        return channel.headers.getRawHeaders("Location")[0]
+        return get_location(channel)
 
     def initiate_sso_ui_auth(
         self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
diff --git a/tests/server.py b/tests/server.py
index a7cc5cd325..40cf5b12c3 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -24,6 +24,7 @@ from typing import (
     MutableMapping,
     Optional,
     Tuple,
+    Type,
     Union,
 )
 
@@ -226,7 +227,7 @@ def make_request(
     path: Union[bytes, str],
     content: Union[bytes, str, JsonDict] = b"",
     access_token: Optional[str] = None,
-    request: Request = SynapseRequest,
+    request: Type[Request] = SynapseRequest,
     shorthand: bool = True,
     federation_auth_origin: Optional[bytes] = None,
     content_is_form: bool = False,
diff --git a/tests/unittest.py b/tests/unittest.py
index ba830618c2..c9a08a3420 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -44,6 +44,7 @@ from twisted.python.threadpool import ThreadPool
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.trial import unittest
 from twisted.web.resource import Resource
+from twisted.web.server import Request
 
 from synapse import events
 from synapse.api.constants import EventTypes, Membership
@@ -95,16 +96,13 @@ def around(target):
     return _around
 
 
-T = TypeVar("T")
-
-
 class TestCase(unittest.TestCase):
     """A subclass of twisted.trial's TestCase which looks for 'loglevel'
     attributes on both itself and its individual test methods, to override the
     root logger's logging level while that test (case|method) runs."""
 
-    def __init__(self, methodName, *args, **kwargs):
-        super().__init__(methodName, *args, **kwargs)
+    def __init__(self, methodName: str):
+        super().__init__(methodName)
 
         method = getattr(self, methodName)
 
@@ -220,16 +218,16 @@ class HomeserverTestCase(TestCase):
     Attributes:
         servlets: List of servlet registration function.
         user_id (str): The user ID to assume if auth is hijacked.
-        hijack_auth (bool): Whether to hijack auth to return the user specified
+        hijack_auth: Whether to hijack auth to return the user specified
         in user_id.
     """
 
-    hijack_auth = True
-    needs_threadpool = False
+    hijack_auth: ClassVar[bool] = True
+    needs_threadpool: ClassVar[bool] = False
     servlets: ClassVar[List[RegisterServletsFunc]] = []
 
-    def __init__(self, methodName, *args, **kwargs):
-        super().__init__(methodName, *args, **kwargs)
+    def __init__(self, methodName: str):
+        super().__init__(methodName)
 
         # see if we have any additional config for this test
         method = getattr(self, methodName)
@@ -301,9 +299,10 @@ class HomeserverTestCase(TestCase):
                         None,
                     )
 
-                self.hs.get_auth().get_user_by_req = get_user_by_req
-                self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
-                self.hs.get_auth().get_access_token_from_request = Mock(
+                # Type ignore: mypy doesn't like us assigning to methods.
+                self.hs.get_auth().get_user_by_req = get_user_by_req  # type: ignore[assignment]
+                self.hs.get_auth().get_user_by_access_token = get_user_by_access_token  # type: ignore[assignment]
+                self.hs.get_auth().get_access_token_from_request = Mock(  # type: ignore[assignment]
                     return_value="1234"
                 )
 
@@ -417,7 +416,7 @@ class HomeserverTestCase(TestCase):
         path: Union[bytes, str],
         content: Union[bytes, str, JsonDict] = b"",
         access_token: Optional[str] = None,
-        request: Type[T] = SynapseRequest,
+        request: Type[Request] = SynapseRequest,
         shorthand: bool = True,
         federation_auth_origin: Optional[bytes] = None,
         content_is_form: bool = False,
@@ -596,7 +595,7 @@ class HomeserverTestCase(TestCase):
             nonce_str += b"\x00notadmin"
 
         want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
-        want_mac = want_mac.hexdigest()
+        want_mac_digest = want_mac.hexdigest()
 
         body = json.dumps(
             {
@@ -605,7 +604,7 @@ class HomeserverTestCase(TestCase):
                 "displayname": displayname,
                 "password": password,
                 "admin": admin,
-                "mac": want_mac,
+                "mac": want_mac_digest,
                 "inhibit_login": True,
             }
         )