summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2021-11-12 15:50:54 +0000
committerGitHub <noreply@github.com>2021-11-12 15:50:54 +0000
commit4c96ce396e900a94af66ec070af925881b6e1e24 (patch)
tree446920e10d66c2ace553a68541ed64e0fbd543e1 /tests
parentGeneralize the disallowed_untyped_defs in mypy.ini (#11322) (diff)
downloadsynapse-4c96ce396e900a94af66ec070af925881b6e1e24.tar.xz
Misc typing fixes for `tests`, part 1 of N (#11323)
* Annotate HomeserverTestCase.servlets
* Correct annotation of federation_auth_origin
* Use AnyStr custom_headers instead of a Union

This allows (str, str) and (bytes, bytes).
This disallows (str, bytes) and (bytes, str)

* DomainSpecificString.SIGIL is a ClassVar
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/_base.py5
-rw-r--r--tests/rest/client/utils.py22
-rw-r--r--tests/server.py15
-rw-r--r--tests/unittest.py32
4 files changed, 47 insertions, 27 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index eac4664b41..cb02eddf07 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet.protocol import Protocol
 from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import GenericWorkerServer
-from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.client import ReplicationDataHandler
@@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
     unlike `BaseStreamTestCase`.
     """
 
-    servlets: List[Callable[[HomeServer, JsonResource], None]] = []
-
     def setUp(self):
         super().setUp()
 
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index ec0979850b..7cf782e2d6 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -19,7 +19,17 @@ import json
 import re
 import time
 import urllib.parse
-from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import (
+    Any,
+    AnyStr,
+    Dict,
+    Iterable,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Union,
+)
 from unittest.mock import patch
 
 import attr
@@ -53,9 +63,7 @@ class RestHelper:
         tok: Optional[str] = None,
         expect_code: int = 200,
         extra_content: Optional[Dict] = None,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ) -> str:
         """
         Create a room.
@@ -227,9 +235,7 @@ class RestHelper:
         txn_id=None,
         tok=None,
         expect_code=200,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         if body is None:
             body = "body_text_here"
@@ -418,7 +424,7 @@ class RestHelper:
             path,
             content=image_data,
             access_token=tok,
-            custom_headers=[(b"Content-Length", str(image_length))],
+            custom_headers=[("Content-Length", str(image_length))],
         )
 
         assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
diff --git a/tests/server.py b/tests/server.py
index 103351b487..a7cc5cd325 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,7 +16,16 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
+from typing import (
+    AnyStr,
+    Callable,
+    Dict,
+    Iterable,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import attr
 from typing_extensions import Deque
@@ -222,9 +231,7 @@ def make_request(
     federation_auth_origin: Optional[bytes] = None,
     content_is_form: bool = False,
     await_result: bool = True,
-    custom_headers: Optional[
-        Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-    ] = None,
+    custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     client_ip: str = "127.0.0.1",
 ) -> FakeChannel:
     """
diff --git a/tests/unittest.py b/tests/unittest.py
index a9b60b7eeb..ba830618c2 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,20 @@ import inspect
 import logging
 import secrets
 import time
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+    Any,
+    AnyStr,
+    Callable,
+    ClassVar,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 from unittest.mock import Mock, patch
 
 from canonicaljson import json
@@ -45,6 +58,7 @@ from synapse.logging.context import (
     current_context,
     set_current_context,
 )
+from synapse.rest import RegisterServletsFunc
 from synapse.server import HomeServer
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
@@ -204,15 +218,15 @@ class HomeserverTestCase(TestCase):
       config dict.
 
     Attributes:
-        servlets (list[function]): List of servlet registration function.
+        servlets: List of servlet registration function.
         user_id (str): The user ID to assume if auth is hijacked.
         hijack_auth (bool): Whether to hijack auth to return the user specified
         in user_id.
     """
 
-    servlets = []
     hijack_auth = True
     needs_threadpool = False
+    servlets: ClassVar[List[RegisterServletsFunc]] = []
 
     def __init__(self, methodName, *args, **kwargs):
         super().__init__(methodName, *args, **kwargs)
@@ -405,12 +419,10 @@ class HomeserverTestCase(TestCase):
         access_token: Optional[str] = None,
         request: Type[T] = SynapseRequest,
         shorthand: bool = True,
-        federation_auth_origin: str = None,
+        federation_auth_origin: Optional[bytes] = None,
         content_is_form: bool = False,
         await_result: bool = True,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """
@@ -425,7 +437,7 @@ class HomeserverTestCase(TestCase):
             a dict.
             shorthand: Whether to try and be helpful and prefix the given URL
             with the usual REST API path, if it doesn't contain it.
-            federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+            federation_auth_origin: if set to not-None, we will add a fake
                 Authorization header pretenting to be the given server name.
             content_is_form: Whether the content is URL encoded form data. Adds the
                 'Content-Type': 'application/x-www-form-urlencoded' header.
@@ -639,9 +651,7 @@ class HomeserverTestCase(TestCase):
         username,
         password,
         device_id=None,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
     ):
         """
         Log in a user, and get an access token. Requires the Login API be