diff --git a/tests/unittest.py b/tests/unittest.py
index 5b19065c71..9afa68c164 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -22,10 +22,11 @@ import secrets
import time
from typing import (
Any,
- AnyStr,
+ Awaitable,
Callable,
ClassVar,
Dict,
+ Generic,
Iterable,
List,
Optional,
@@ -39,6 +40,7 @@ from unittest.mock import Mock, patch
import canonicaljson
import signedjson.key
import unpaddedbase64
+from typing_extensions import Protocol
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
@@ -49,7 +51,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse import events
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
@@ -70,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
-from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
+from tests.server import (
+ CustomHeaderType,
+ FakeChannel,
+ get_clock,
+ make_request,
+ setup_test_homeserver,
+)
from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -78,6 +86,17 @@ from tests.utils import default_config, setupdb
setupdb()
setup_logging()
+TV = TypeVar("TV")
+_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
+
+
+class _TypedFailure(Generic[_ExcType], Protocol):
+ """Extension to twisted.Failure, where the 'value' has a certain type."""
+
+ @property
+ def value(self) -> _ExcType:
+ ...
+
def around(target):
"""A CLOS-style 'around' modifier, which wraps the original method of the
@@ -276,6 +295,7 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
+ assert self.helper.auth_user_id is not None
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
@@ -288,6 +308,7 @@ class HomeserverTestCase(TestCase):
)
async def get_user_by_access_token(token=None, allow_guest=False):
+ assert self.helper.auth_user_id is not None
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": token_id,
@@ -295,6 +316,7 @@ class HomeserverTestCase(TestCase):
}
async def get_user_by_req(request, allow_guest=False, rights="access"):
+ assert self.helper.auth_user_id is not None
return create_requester(
UserID.from_string(self.helper.auth_user_id),
token_id,
@@ -311,7 +333,7 @@ class HomeserverTestCase(TestCase):
)
if self.needs_threadpool:
- self.reactor.threadpool = ThreadPool()
+ self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
self.addCleanup(self.reactor.threadpool.stop)
self.reactor.threadpool.start()
@@ -426,7 +448,7 @@ class HomeserverTestCase(TestCase):
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
- custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
@@ -511,30 +533,36 @@ class HomeserverTestCase(TestCase):
return hs
- def pump(self, by=0.0):
+ def pump(self, by: float = 0.0) -> None:
"""
Pump the reactor enough that Deferreds will fire.
"""
self.reactor.pump([by] * 100)
- def get_success(self, d, by=0.0):
- deferred: Deferred[TV] = ensureDeferred(d)
+ def get_success(
+ self,
+ d: Awaitable[TV],
+ by: float = 0.0,
+ ) -> TV:
+ deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by)
return self.successResultOf(deferred)
- def get_failure(self, d, exc):
+ def get_failure(
+ self, d: Awaitable[Any], exc: Type[_ExcType]
+ ) -> _TypedFailure[_ExcType]:
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
- deferred: Deferred[Any] = ensureDeferred(d)
+ deferred: Deferred[Any] = ensureDeferred(d) # type: ignore[arg-type]
self.pump()
return self.failureResultOf(deferred, exc)
- def get_success_or_raise(self, d, by=0.0):
+ def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
"""Drive deferred to completion and return result or raise exception
on failure.
"""
- deferred: Deferred[TV] = ensureDeferred(d)
+ deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
results: list = []
deferred.addBoth(results.append)
@@ -642,11 +670,11 @@ class HomeserverTestCase(TestCase):
def login(
self,
- username,
- password,
- device_id=None,
- custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
- ):
+ username: str,
+ password: str,
+ device_id: Optional[str] = None,
+ custom_headers: Optional[Iterable[CustomHeaderType]] = None,
+ ) -> str:
"""
Log in a user, and get an access token. Requires the Login API be
registered.
@@ -668,18 +696,22 @@ class HomeserverTestCase(TestCase):
return access_token
def create_and_send_event(
- self, room_id, user, soft_failed=False, prev_event_ids=None
- ):
+ self,
+ room_id: str,
+ user: UserID,
+ soft_failed: bool = False,
+ prev_event_ids: Optional[List[str]] = None,
+ ) -> str:
"""
Create and send an event.
Args:
- soft_failed (bool): Whether to create a soft failed event or not
- prev_event_ids (list[str]|None): Explicitly set the prev events,
+ soft_failed: Whether to create a soft failed event or not
+ prev_event_ids: Explicitly set the prev events,
or if None just use the default
Returns:
- str: The new event's ID.
+ The new event's ID.
"""
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
@@ -706,7 +738,7 @@ class HomeserverTestCase(TestCase):
return event.event_id
- def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+ def inject_room_member(self, room: str, user: str, membership: str) -> None:
"""
Inject a membership event into a room.
@@ -766,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
path: str,
content: Optional[JsonDict] = None,
await_result: bool = True,
- custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ custom_headers: Optional[Iterable[CustomHeaderType]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""Make an inbound signed federation request to this server
@@ -799,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
self.site,
method=method,
path=path,
- content=content,
+ content=content or "",
shorthand=False,
await_result=await_result,
custom_headers=custom_headers,
@@ -878,9 +910,6 @@ def override_config(extra_config):
return decorator
-TV = TypeVar("TV")
-
-
def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
"""A test decorator which will skip the decorated test unless a condition is set
|