diff --git a/tests/unittest.py b/tests/unittest.py
index a9b60b7eeb..165aafc574 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,20 @@ import inspect
import logging
import secrets
import time
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import (
+ Any,
+ AnyStr,
+ Callable,
+ ClassVar,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
from unittest.mock import Mock, patch
from canonicaljson import json
@@ -31,6 +44,7 @@ from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from twisted.web.resource import Resource
+from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
@@ -45,6 +59,7 @@ from synapse.logging.context import (
current_context,
set_current_context,
)
+from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -81,16 +96,13 @@ def around(target):
return _around
-T = TypeVar("T")
-
-
class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs."""
- def __init__(self, methodName, *args, **kwargs):
- super().__init__(methodName, *args, **kwargs)
+ def __init__(self, methodName: str):
+ super().__init__(methodName)
method = getattr(self, methodName)
@@ -204,18 +216,18 @@ class HomeserverTestCase(TestCase):
config dict.
Attributes:
- servlets (list[function]): List of servlet registration function.
+ servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
- hijack_auth (bool): Whether to hijack auth to return the user specified
+ hijack_auth: Whether to hijack auth to return the user specified
in user_id.
"""
- servlets = []
- hijack_auth = True
- needs_threadpool = False
+ hijack_auth: ClassVar[bool] = True
+ needs_threadpool: ClassVar[bool] = False
+ servlets: ClassVar[List[RegisterServletsFunc]] = []
- def __init__(self, methodName, *args, **kwargs):
- super().__init__(methodName, *args, **kwargs)
+ def __init__(self, methodName: str):
+ super().__init__(methodName)
# see if we have any additional config for this test
method = getattr(self, methodName)
@@ -287,9 +299,10 @@ class HomeserverTestCase(TestCase):
None,
)
- self.hs.get_auth().get_user_by_req = get_user_by_req
- self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
- self.hs.get_auth().get_access_token_from_request = Mock(
+ # Type ignore: mypy doesn't like us assigning to methods.
+ self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
+ self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
+ self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
return_value="1234"
)
@@ -318,7 +331,12 @@ class HomeserverTestCase(TestCase):
time.sleep(0.01)
def wait_for_background_updates(self) -> None:
- """Block until all background database updates have completed."""
+ """
+ Block until all background database updates have completed.
+
+ Note that callers must ensure that's a store property created on the
+ testcase.
+ """
while not self.get_success(
self.store.db_pool.updates.has_completed_background_updates()
):
@@ -403,14 +421,12 @@ class HomeserverTestCase(TestCase):
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
- request: Type[T] = SynapseRequest,
+ request: Type[Request] = SynapseRequest,
shorthand: bool = True,
- federation_auth_origin: str = None,
+ federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
@@ -425,7 +441,7 @@ class HomeserverTestCase(TestCase):
a dict.
shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it.
- federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+ federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
@@ -584,7 +600,7 @@ class HomeserverTestCase(TestCase):
nonce_str += b"\x00notadmin"
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
- want_mac = want_mac.hexdigest()
+ want_mac_digest = want_mac.hexdigest()
body = json.dumps(
{
@@ -593,7 +609,7 @@ class HomeserverTestCase(TestCase):
"displayname": displayname,
"password": password,
"admin": admin,
- "mac": want_mac,
+ "mac": want_mac_digest,
"inhibit_login": True,
}
)
@@ -639,9 +655,7 @@ class HomeserverTestCase(TestCase):
username,
password,
device_id=None,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
"""
Log in a user, and get an access token. Requires the Login API be
|