summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11323.misc1
-rw-r--r--synapse/rest/__init__.py4
-rw-r--r--synapse/types.py3
-rw-r--r--tests/replication/_base.py5
-rw-r--r--tests/rest/client/utils.py22
-rw-r--r--tests/server.py15
-rw-r--r--tests/unittest.py32
7 files changed, 53 insertions, 29 deletions
diff --git a/changelog.d/11323.misc b/changelog.d/11323.misc
new file mode 100644
index 0000000000..54f39e1844
--- /dev/null
+++ b/changelog.d/11323.misc
@@ -0,0 +1 @@
+Improve type annotations in Synapse's test suite.
\ No newline at end of file
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index e04af705eb..cebdeecb81 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable
 
 from synapse.http.server import HttpServer, JsonResource
 from synapse.rest import admin
@@ -62,6 +62,8 @@ from synapse.rest.client import (
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+RegisterServletsFunc = Callable[["HomeServer", HttpServer], None]
+
 
 class ClientRestResource(JsonResource):
     """Matrix Client API REST resource.
diff --git a/synapse/types.py b/synapse/types.py
index 9d7a675662..fb72f19343 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -19,6 +19,7 @@ from collections import namedtuple
 from typing import (
     TYPE_CHECKING,
     Any,
+    ClassVar,
     Dict,
     Mapping,
     MutableMapping,
@@ -219,7 +220,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
         'domain' : The domain part of the name
     """
 
-    SIGIL: str = abc.abstractproperty()  # type: ignore
+    SIGIL: ClassVar[str] = abc.abstractproperty()  # type: ignore
 
     localpart = attr.ib(type=str)
     domain = attr.ib(type=str)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index eac4664b41..cb02eddf07 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet.protocol import Protocol
 from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import GenericWorkerServer
-from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.client import ReplicationDataHandler
@@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
     unlike `BaseStreamTestCase`.
     """
 
-    servlets: List[Callable[[HomeServer, JsonResource], None]] = []
-
     def setUp(self):
         super().setUp()
 
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index ec0979850b..7cf782e2d6 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -19,7 +19,17 @@ import json
 import re
 import time
 import urllib.parse
-from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import (
+    Any,
+    AnyStr,
+    Dict,
+    Iterable,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Union,
+)
 from unittest.mock import patch
 
 import attr
@@ -53,9 +63,7 @@ class RestHelper:
         tok: Optional[str] = None,
         expect_code: int = 200,
         extra_content: Optional[Dict] = None,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ) -> str:
         """
         Create a room.
@@ -227,9 +235,7 @@ class RestHelper:
         txn_id=None,
         tok=None,
         expect_code=200,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         if body is None:
             body = "body_text_here"
@@ -418,7 +424,7 @@ class RestHelper:
             path,
             content=image_data,
             access_token=tok,
-            custom_headers=[(b"Content-Length", str(image_length))],
+            custom_headers=[("Content-Length", str(image_length))],
         )
 
         assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
diff --git a/tests/server.py b/tests/server.py
index 103351b487..a7cc5cd325 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,7 +16,16 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
+from typing import (
+    AnyStr,
+    Callable,
+    Dict,
+    Iterable,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import attr
 from typing_extensions import Deque
@@ -222,9 +231,7 @@ def make_request(
     federation_auth_origin: Optional[bytes] = None,
     content_is_form: bool = False,
     await_result: bool = True,
-    custom_headers: Optional[
-        Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-    ] = None,
+    custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     client_ip: str = "127.0.0.1",
 ) -> FakeChannel:
     """
diff --git a/tests/unittest.py b/tests/unittest.py
index a9b60b7eeb..ba830618c2 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,20 @@ import inspect
 import logging
 import secrets
 import time
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+    Any,
+    AnyStr,
+    Callable,
+    ClassVar,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 from unittest.mock import Mock, patch
 
 from canonicaljson import json
@@ -45,6 +58,7 @@ from synapse.logging.context import (
     current_context,
     set_current_context,
 )
+from synapse.rest import RegisterServletsFunc
 from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
@@ -204,15 +218,15 @@ class HomeserverTestCase(TestCase):
       config dict.
 
     Attributes:
-        servlets (list[function]): List of servlet registration function.
+        servlets: List of servlet registration function.
         user_id (str): The user ID to assume if auth is hijacked.
         hijack_auth (bool): Whether to hijack auth to return the user specified
         in user_id.
     """
 
-    servlets = []
     hijack_auth = True
     needs_threadpool = False
+    servlets: ClassVar[List[RegisterServletsFunc]] = []
 
     def __init__(self, methodName, *args, **kwargs):
         super().__init__(methodName, *args, **kwargs)
@@ -405,12 +419,10 @@ class HomeserverTestCase(TestCase):
         access_token: Optional[str] = None,
         request: Type[T] = SynapseRequest,
         shorthand: bool = True,
-        federation_auth_origin: str = None,
+        federation_auth_origin: Optional[bytes] = None,
         content_is_form: bool = False,
         await_result: bool = True,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """
@@ -425,7 +437,7 @@ class HomeserverTestCase(TestCase):
             a dict.
             shorthand: Whether to try and be helpful and prefix the given URL
             with the usual REST API path, if it doesn't contain it.
-            federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+            federation_auth_origin: if set to not-None, we will add a fake
                 Authorization header pretenting to be the given server name.
             content_is_form: Whether the content is URL encoded form data. Adds the
                 'Content-Type': 'application/x-www-form-urlencoded' header.
@@ -639,9 +651,7 @@ class HomeserverTestCase(TestCase):
         username,
         password,
         device_id=None,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         """
         Log in a user, and get an access token. Requires the Login API be