summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py127
1 files changed, 59 insertions, 68 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index 326895f4c9..9afa68c164 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -16,17 +16,17 @@
 import gc
 import hashlib
 import hmac
-import inspect
 import json
 import logging
 import secrets
 import time
 from typing import (
     Any,
-    AnyStr,
+    Awaitable,
     Callable,
     ClassVar,
     Dict,
+    Generic,
     Iterable,
     List,
     Optional,
@@ -40,6 +40,7 @@ from unittest.mock import Mock, patch
 import canonicaljson
 import signedjson.key
 import unpaddedbase64
+from typing_extensions import Protocol
 
 from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.python.failure import Failure
@@ -50,7 +51,7 @@ from twisted.web.resource import Resource
 from twisted.web.server import Request
 
 from synapse import events
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
@@ -71,7 +72,13 @@ from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 
-from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
+from tests.server import (
+    CustomHeaderType,
+    FakeChannel,
+    get_clock,
+    make_request,
+    setup_test_homeserver,
+)
 from tests.test_utils import event_injection, setup_awaitable_errors
 from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
@@ -79,6 +86,17 @@ from tests.utils import default_config, setupdb
 setupdb()
 setup_logging()
 
+TV = TypeVar("TV")
+_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
+
+
+class _TypedFailure(Generic[_ExcType], Protocol):
+    """Extension to twisted.Failure, where the 'value' has a certain type."""
+
+    @property
+    def value(self) -> _ExcType:
+        ...
+
 
 def around(target):
     """A CLOS-style 'around' modifier, which wraps the original method of the
@@ -277,6 +295,7 @@ class HomeserverTestCase(TestCase):
 
         if hasattr(self, "user_id"):
             if self.hijack_auth:
+                assert self.helper.auth_user_id is not None
 
                 # We need a valid token ID to satisfy foreign key constraints.
                 token_id = self.get_success(
@@ -289,6 +308,7 @@ class HomeserverTestCase(TestCase):
                 )
 
                 async def get_user_by_access_token(token=None, allow_guest=False):
+                    assert self.helper.auth_user_id is not None
                     return {
                         "user": UserID.from_string(self.helper.auth_user_id),
                         "token_id": token_id,
@@ -296,6 +316,7 @@ class HomeserverTestCase(TestCase):
                     }
 
                 async def get_user_by_req(request, allow_guest=False, rights="access"):
+                    assert self.helper.auth_user_id is not None
                     return create_requester(
                         UserID.from_string(self.helper.auth_user_id),
                         token_id,
@@ -312,7 +333,7 @@ class HomeserverTestCase(TestCase):
                 )
 
         if self.needs_threadpool:
-            self.reactor.threadpool = ThreadPool()
+            self.reactor.threadpool = ThreadPool()  # type: ignore[assignment]
             self.addCleanup(self.reactor.threadpool.stop)
             self.reactor.threadpool.start()
 
@@ -427,7 +448,7 @@ class HomeserverTestCase(TestCase):
         federation_auth_origin: Optional[bytes] = None,
         content_is_form: bool = False,
         await_result: bool = True,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """
@@ -512,40 +533,36 @@ class HomeserverTestCase(TestCase):
 
         return hs
 
-    def pump(self, by=0.0):
+    def pump(self, by: float = 0.0) -> None:
         """
         Pump the reactor enough that Deferreds will fire.
         """
         self.reactor.pump([by] * 100)
 
-    def get_success(self, d, by=0.0):
-        if inspect.isawaitable(d):
-            d = ensureDeferred(d)
-        if not isinstance(d, Deferred):
-            return d
+    def get_success(
+        self,
+        d: Awaitable[TV],
+        by: float = 0.0,
+    ) -> TV:
+        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
         self.pump(by=by)
-        return self.successResultOf(d)
+        return self.successResultOf(deferred)
 
-    def get_failure(self, d, exc):
+    def get_failure(
+        self, d: Awaitable[Any], exc: Type[_ExcType]
+    ) -> _TypedFailure[_ExcType]:
         """
         Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
         """
-        if inspect.isawaitable(d):
-            d = ensureDeferred(d)
-        if not isinstance(d, Deferred):
-            return d
+        deferred: Deferred[Any] = ensureDeferred(d)  # type: ignore[arg-type]
         self.pump()
-        return self.failureResultOf(d, exc)
+        return self.failureResultOf(deferred, exc)
 
-    def get_success_or_raise(self, d, by=0.0):
+    def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
         """Drive deferred to completion and return result or raise exception
         on failure.
         """
-
-        if inspect.isawaitable(d):
-            deferred = ensureDeferred(d)
-        if not isinstance(deferred, Deferred):
-            return d
+        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
 
         results: list = []
         deferred.addBoth(results.append)
@@ -653,11 +670,11 @@ class HomeserverTestCase(TestCase):
 
     def login(
         self,
-        username,
-        password,
-        device_id=None,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-    ):
+        username: str,
+        password: str,
+        device_id: Optional[str] = None,
+        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
+    ) -> str:
         """
         Log in a user, and get an access token. Requires the Login API be
         registered.
@@ -679,18 +696,22 @@ class HomeserverTestCase(TestCase):
         return access_token
 
     def create_and_send_event(
-        self, room_id, user, soft_failed=False, prev_event_ids=None
-    ):
+        self,
+        room_id: str,
+        user: UserID,
+        soft_failed: bool = False,
+        prev_event_ids: Optional[List[str]] = None,
+    ) -> str:
         """
         Create and send an event.
 
         Args:
-            soft_failed (bool): Whether to create a soft failed event or not
-            prev_event_ids (list[str]|None): Explicitly set the prev events,
+            soft_failed: Whether to create a soft failed event or not
+            prev_event_ids: Explicitly set the prev events,
                 or if None just use the default
 
         Returns:
-            str: The new event's ID.
+            The new event's ID.
         """
         event_creator = self.hs.get_event_creation_handler()
         requester = create_requester(user)
@@ -717,34 +738,7 @@ class HomeserverTestCase(TestCase):
 
         return event.event_id
 
-    def add_extremity(self, room_id, event_id):
-        """
-        Add the given event as an extremity to the room.
-        """
-        self.get_success(
-            self.hs.get_datastores().main.db_pool.simple_insert(
-                table="event_forward_extremities",
-                values={"room_id": room_id, "event_id": event_id},
-                desc="test_add_extremity",
-            )
-        )
-
-        self.hs.get_datastores().main.get_latest_event_ids_in_room.invalidate(
-            (room_id,)
-        )
-
-    def attempt_wrong_password_login(self, username, password):
-        """Attempts to login as the user with the given password, asserting
-        that the attempt *fails*.
-        """
-        body = {"type": "m.login.password", "user": username, "password": password}
-
-        channel = self.make_request(
-            "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
-        )
-        self.assertEqual(channel.code, 403, channel.result)
-
-    def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+    def inject_room_member(self, room: str, user: str, membership: str) -> None:
         """
         Inject a membership event into a room.
 
@@ -804,7 +798,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
         path: str,
         content: Optional[JsonDict] = None,
         await_result: bool = True,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
         client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """Make an inbound signed federation request to this server
@@ -837,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
             self.site,
             method=method,
             path=path,
-            content=content,
+            content=content or "",
             shorthand=False,
             await_result=await_result,
             custom_headers=custom_headers,
@@ -916,9 +910,6 @@ def override_config(extra_config):
     return decorator
 
 
-TV = TypeVar("TV")
-
-
 def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
     """A test decorator which will skip the decorated test unless a condition is set