diff --git a/changelog.d/11323.misc b/changelog.d/11323.misc
new file mode 100644
index 0000000000..54f39e1844
--- /dev/null
+++ b/changelog.d/11323.misc
@@ -0,0 +1 @@
+Improve type annotations in Synapse's test suite.
\ No newline at end of file
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index e04af705eb..cebdeecb81 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable
from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
@@ -62,6 +62,8 @@ from synapse.rest.client import (
if TYPE_CHECKING:
from synapse.server import HomeServer
+RegisterServletsFunc = Callable[["HomeServer", HttpServer], None]
+
class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.
diff --git a/synapse/types.py b/synapse/types.py
index 9d7a675662..fb72f19343 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -19,6 +19,7 @@ from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
+ ClassVar,
Dict,
Mapping,
MutableMapping,
@@ -219,7 +220,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
'domain' : The domain part of the name
"""
- SIGIL: str = abc.abstractproperty() # type: ignore
+ SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore
localpart = attr.ib(type=str)
domain = attr.ib(type=str)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index eac4664b41..cb02eddf07 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.protocol import Protocol
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
-from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
@@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""
- servlets: List[Callable[[HomeServer, JsonResource], None]] = []
-
def setUp(self):
super().setUp()
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index ec0979850b..7cf782e2d6 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -19,7 +19,17 @@ import json
import re
import time
import urllib.parse
-from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
+from typing import (
+ Any,
+ AnyStr,
+ Dict,
+ Iterable,
+ Mapping,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Union,
+)
from unittest.mock import patch
import attr
@@ -53,9 +63,7 @@ class RestHelper:
tok: Optional[str] = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
) -> str:
"""
Create a room.
@@ -227,9 +235,7 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if body is None:
body = "body_text_here"
@@ -418,7 +424,7 @@ class RestHelper:
path,
content=image_data,
access_token=tok,
- custom_headers=[(b"Content-Length", str(image_length))],
+ custom_headers=[("Content-Length", str(image_length))],
)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
diff --git a/tests/server.py b/tests/server.py
index 103351b487..a7cc5cd325 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -16,7 +16,16 @@ import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
-from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
+from typing import (
+ AnyStr,
+ Callable,
+ Dict,
+ Iterable,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Union,
+)
import attr
from typing_extensions import Deque
@@ -222,9 +231,7 @@ def make_request(
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
diff --git a/tests/unittest.py b/tests/unittest.py
index a9b60b7eeb..ba830618c2 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,20 @@ import inspect
import logging
import secrets
import time
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+ Any,
+ AnyStr,
+ Callable,
+ ClassVar,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
from unittest.mock import Mock, patch
from canonicaljson import json
@@ -45,6 +58,7 @@ from synapse.logging.context import (
current_context,
set_current_context,
)
+from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -204,15 +218,15 @@ class HomeserverTestCase(TestCase):
config dict.
Attributes:
- servlets (list[function]): List of servlet registration function.
+ servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
hijack_auth (bool): Whether to hijack auth to return the user specified
in user_id.
"""
- servlets = []
hijack_auth = True
needs_threadpool = False
+ servlets: ClassVar[List[RegisterServletsFunc]] = []
def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs)
@@ -405,12 +419,10 @@ class HomeserverTestCase(TestCase):
access_token: Optional[str] = None,
request: Type[T] = SynapseRequest,
shorthand: bool = True,
- federation_auth_origin: str = None,
+ federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
@@ -425,7 +437,7 @@ class HomeserverTestCase(TestCase):
a dict.
shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it.
- federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+ federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
@@ -639,9 +651,7 @@ class HomeserverTestCase(TestCase):
username,
password,
device_id=None,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
"""
Log in a user, and get an access token. Requires the Login API be
|