summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-07-27 13:18:41 -0400
committerGitHub <noreply@github.com>2022-07-27 17:18:41 +0000
commit922b771337f6d14a556fa761c783748f698e924b (patch)
tree7ee9ff2cdca63b9c912c818902009870ae93a90d /tests
parentImplement MSC3848: Introduce errcodes for specific event sending failures (#1... (diff)
downloadsynapse-922b771337f6d14a556fa761c783748f698e924b.tar.xz
Add missing type hints for tests.unittest. (#13397)
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_directory.py12
-rw-r--r--tests/rest/client/test_relations.py6
-rw-r--r--tests/rest/client/test_rooms.py2
-rw-r--r--tests/server.py11
-rw-r--r--tests/unittest.py86
5 files changed, 65 insertions, 52 deletions
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 53d49ca896..3b72c4c9d0 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -481,17 +481,13 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(
-        self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
-    ) -> HomeServer:
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
         self.allowed_access_token = self.login(self.allowed_localpart, "pass")
 
         self.denied_user_id = self.register_user("denied", "pass")
         self.denied_access_token = self.login("denied", "pass")
 
-        return hs
-
     def test_denied_without_publication_permission(self) -> None:
         """
         Try to create a room, register an alias for it, and publish it,
@@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
 
     servlets = [directory.register_servlets, room.register_servlets]
 
-    def prepare(
-        self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
-    ) -> HomeServer:
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         room_id = self.helper.create_room_as(self.user_id)
 
         channel = self.make_request(
@@ -588,8 +582,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         self.room_list_handler = hs.get_room_list_handler()
         self.directory_handler = hs.get_directory_handler()
 
-        return hs
-
     def test_disabling_room_list(self) -> None:
         self.room_list_handler.enable_room_list_search = True
         self.directory_handler.enable_room_list_search = True
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index ad03eee17b..d589f07314 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1060,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                     participated, bundled_aggregations.get("current_user_participated")
                 )
                 # The latest thread event has some fields that don't matter.
+                self.assertIn("latest_event", bundled_aggregations)
                 self.assert_dict(
                     {
                         "content": {
@@ -1072,7 +1073,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                         "sender": self.user2_id,
                         "type": "m.room.test",
                     },
-                    bundled_aggregations.get("latest_event"),
+                    bundled_aggregations["latest_event"],
                 )
 
             return assert_thread
@@ -1112,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
             self.assertEqual(2, bundled_aggregations.get("count"))
             self.assertTrue(bundled_aggregations.get("current_user_participated"))
             # The latest thread event has some fields that don't matter.
+            self.assertIn("latest_event", bundled_aggregations)
             self.assert_dict(
                 {
                     "content": {
@@ -1124,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                     "sender": self.user_id,
                     "type": "m.room.test",
                 },
-                bundled_aggregations.get("latest_event"),
+                bundled_aggregations["latest_event"],
             )
             # Check the unsigned field on the latest event.
             self.assert_dict(
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index c45cb32090..2272d55d84 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -496,7 +496,7 @@ class RoomStateTestCase(RoomBase):
 
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
         self.assertCountEqual(
-            [state_event["type"] for state_event in channel.json_body],
+            [state_event["type"] for state_event in channel.json_list],
             {
                 "m.room.create",
                 "m.room.power_levels",
diff --git a/tests/server.py b/tests/server.py
index df3f1564c9..9689e6a0cd 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -25,6 +25,7 @@ from typing import (
     Callable,
     Dict,
     Iterable,
+    List,
     MutableMapping,
     Optional,
     Tuple,
@@ -121,7 +122,15 @@ class FakeChannel:
 
     @property
     def json_body(self) -> JsonDict:
-        return json.loads(self.text_body)
+        body = json.loads(self.text_body)
+        assert isinstance(body, dict)
+        return body
+
+    @property
+    def json_list(self) -> List[JsonDict]:
+        body = json.loads(self.text_body)
+        assert isinstance(body, list)
+        return body
 
     @property
     def text_body(self) -> str:
diff --git a/tests/unittest.py b/tests/unittest.py
index 66ce92f4a6..bec4a3d023 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -28,6 +28,7 @@ from typing import (
     Generic,
     Iterable,
     List,
+    NoReturn,
     Optional,
     Tuple,
     Type,
@@ -39,7 +40,7 @@ from unittest.mock import Mock, patch
 import canonicaljson
 import signedjson.key
 import unpaddedbase64
-from typing_extensions import Protocol
+from typing_extensions import Concatenate, ParamSpec, Protocol
 
 from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.python.failure import Failure
@@ -67,7 +68,7 @@ from synapse.logging.context import (
 from synapse.rest import RegisterServletsFunc
 from synapse.server import HomeServer
 from synapse.storage.keys import FetchKeyResult
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 
@@ -88,6 +89,10 @@ setup_logging()
 TV = TypeVar("TV")
 _ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
 
+P = ParamSpec("P")
+R = TypeVar("R")
+S = TypeVar("S")
+
 
 class _TypedFailure(Generic[_ExcType], Protocol):
     """Extension to twisted.Failure, where the 'value' has a certain type."""
@@ -97,7 +102,7 @@ class _TypedFailure(Generic[_ExcType], Protocol):
         ...
 
 
-def around(target):
+def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
     """A CLOS-style 'around' modifier, which wraps the original method of the
     given instance with another piece of code.
 
@@ -106,11 +111,11 @@ def around(target):
         return orig(*args, **kwargs)
     """
 
-    def _around(code):
+    def _around(code: Callable[Concatenate[S, P], R]) -> None:
         name = code.__name__
         orig = getattr(target, name)
 
-        def new(*args, **kwargs):
+        def new(*args: P.args, **kwargs: P.kwargs) -> R:
             return code(orig, *args, **kwargs)
 
         setattr(target, name, new)
@@ -131,7 +136,7 @@ class TestCase(unittest.TestCase):
         level = getattr(method, "loglevel", getattr(self, "loglevel", None))
 
         @around(self)
-        def setUp(orig):
+        def setUp(orig: Callable[[], R]) -> R:
             # if we're not starting in the sentinel logcontext, then to be honest
             # all future bets are off.
             if current_context():
@@ -144,7 +149,7 @@ class TestCase(unittest.TestCase):
             if level is not None and old_level != level:
 
                 @around(self)
-                def tearDown(orig):
+                def tearDown(orig: Callable[[], R]) -> R:
                     ret = orig()
                     logging.getLogger().setLevel(old_level)
                     return ret
@@ -158,7 +163,7 @@ class TestCase(unittest.TestCase):
             return orig()
 
         @around(self)
-        def tearDown(orig):
+        def tearDown(orig: Callable[[], R]) -> R:
             ret = orig()
             # force a GC to workaround problems with deferreds leaking logcontexts when
             # they are GCed (see the logcontext docs)
@@ -167,7 +172,7 @@ class TestCase(unittest.TestCase):
 
             return ret
 
-    def assertObjectHasAttributes(self, attrs, obj):
+    def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
         """Asserts that the given object has each of the attributes given, and
         that the value of each matches according to assertEqual."""
         for key in attrs.keys():
@@ -178,12 +183,12 @@ class TestCase(unittest.TestCase):
             except AssertionError as e:
                 raise (type(e))(f"Assert error for '.{key}':") from e
 
-    def assert_dict(self, required, actual):
+    def assert_dict(self, required: dict, actual: dict) -> None:
         """Does a partial assert of a dict.
 
         Args:
-            required (dict): The keys and value which MUST be in 'actual'.
-            actual (dict): The test result. Extra keys will not be checked.
+            required: The keys and value which MUST be in 'actual'.
+            actual: The test result. Extra keys will not be checked.
         """
         for key in required:
             self.assertEqual(
@@ -191,31 +196,31 @@ class TestCase(unittest.TestCase):
             )
 
 
-def DEBUG(target):
+def DEBUG(target: TV) -> TV:
     """A decorator to set the .loglevel attribute to logging.DEBUG.
     Can apply to either a TestCase or an individual test method."""
-    target.loglevel = logging.DEBUG
+    target.loglevel = logging.DEBUG  # type: ignore[attr-defined]
     return target
 
 
-def INFO(target):
+def INFO(target: TV) -> TV:
     """A decorator to set the .loglevel attribute to logging.INFO.
     Can apply to either a TestCase or an individual test method."""
-    target.loglevel = logging.INFO
+    target.loglevel = logging.INFO  # type: ignore[attr-defined]
     return target
 
 
-def logcontext_clean(target):
+def logcontext_clean(target: TV) -> TV:
     """A decorator which marks the TestCase or method as 'logcontext_clean'
 
     ... ie, any logcontext errors should cause a test failure
     """
 
-    def logcontext_error(msg):
+    def logcontext_error(msg: str) -> NoReturn:
         raise AssertionError("logcontext error: %s" % (msg))
 
     patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
-    return patcher(target)
+    return patcher(target)  # type: ignore[call-overload]
 
 
 class HomeserverTestCase(TestCase):
@@ -255,7 +260,7 @@ class HomeserverTestCase(TestCase):
         method = getattr(self, methodName)
         self._extra_config = getattr(method, "_extra_config", None)
 
-    def setUp(self):
+    def setUp(self) -> None:
         """
         Set up the TestCase by calling the homeserver constructor, optionally
         hijacking the authentication system to return a fixed user, and then
@@ -306,7 +311,9 @@ class HomeserverTestCase(TestCase):
                     )
                 )
 
-                async def get_user_by_access_token(token=None, allow_guest=False):
+                async def get_user_by_access_token(
+                    token: Optional[str] = None, allow_guest: bool = False
+                ) -> JsonDict:
                     assert self.helper.auth_user_id is not None
                     return {
                         "user": UserID.from_string(self.helper.auth_user_id),
@@ -314,7 +321,11 @@ class HomeserverTestCase(TestCase):
                         "is_guest": False,
                     }
 
-                async def get_user_by_req(request, allow_guest=False):
+                async def get_user_by_req(
+                    request: SynapseRequest,
+                    allow_guest: bool = False,
+                    allow_expired: bool = False,
+                ) -> Requester:
                     assert self.helper.auth_user_id is not None
                     return create_requester(
                         UserID.from_string(self.helper.auth_user_id),
@@ -339,11 +350,11 @@ class HomeserverTestCase(TestCase):
         if hasattr(self, "prepare"):
             self.prepare(self.reactor, self.clock, self.hs)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         # Reset to not use frozen dicts.
         events.USE_FROZEN_DICTS = False
 
-    def wait_on_thread(self, deferred, timeout=10):
+    def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
         """
         Wait until a Deferred is done, where it's waiting on a real thread.
         """
@@ -374,7 +385,7 @@ class HomeserverTestCase(TestCase):
             clock (synapse.util.Clock): The Clock, associated with the reactor.
 
         Returns:
-            A homeserver (synapse.server.HomeServer) suitable for testing.
+            A homeserver suitable for testing.
 
         Function to be overridden in subclasses.
         """
@@ -408,7 +419,7 @@ class HomeserverTestCase(TestCase):
             "/_synapse/admin": servlet_resource,
         }
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         """
         Get a default HomeServer config dict.
         """
@@ -421,7 +432,9 @@ class HomeserverTestCase(TestCase):
 
         return config
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         """
         Prepare for the test.  This involves things like mocking out parts of
         the homeserver, or building test data common across the whole test
@@ -519,7 +532,7 @@ class HomeserverTestCase(TestCase):
         config_obj.parse_config_dict(config, "", "")
         kwargs["config"] = config_obj
 
-        async def run_bg_updates():
+        async def run_bg_updates() -> None:
             with LoggingContext("run_bg_updates"):
                 self.get_success(stor.db_pool.updates.run_background_updates(False))
 
@@ -538,11 +551,7 @@ class HomeserverTestCase(TestCase):
         """
         self.reactor.pump([by] * 100)
 
-    def get_success(
-        self,
-        d: Awaitable[TV],
-        by: float = 0.0,
-    ) -> TV:
+    def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
         deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
         self.pump(by=by)
         return self.successResultOf(deferred)
@@ -755,7 +764,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
     OTHER_SERVER_NAME = "other.example.com"
     OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
 
         # poke the other server's signing key into the key store, so that we don't
@@ -879,7 +888,7 @@ def _auth_header_for_request(
     )
 
 
-def override_config(extra_config):
+def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
     """A decorator which can be applied to test functions to give additional HS config
 
     For use
@@ -892,12 +901,13 @@ def override_config(extra_config):
                 ...
 
     Args:
-        extra_config(dict): Additional config settings to be merged into the default
+        extra_config: Additional config settings to be merged into the default
             config dict before instantiating the test homeserver.
     """
 
-    def decorator(func):
-        func._extra_config = extra_config
+    def decorator(func: TV) -> TV:
+        # This attribute is being defined.
+        func._extra_config = extra_config  # type: ignore[attr-defined]
         return func
 
     return decorator