summary refs log tree commit diff
path: root/tests/rest/client/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/utils.py')
-rw-r--r--tests/rest/client/utils.py36
1 files changed, 25 insertions, 11 deletions
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 46cd5f70a8..28663826fc 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -41,6 +41,7 @@ from twisted.web.resource import Resource
 from twisted.web.server import Site
 
 from synapse.api.constants import Membership
+from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests.server import FakeChannel, FakeSite, make_request
@@ -48,15 +49,15 @@ from tests.test_utils import FakeResponse
 from tests.test_utils.html_parsers import TestHtmlParser
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class RestHelper:
     """Contains extra helper functions to quickly and clearly perform a given
     REST action, which isn't the focus of the test.
     """
 
-    hs = attr.ib()
-    site = attr.ib(type=Site)
-    auth_user_id = attr.ib()
+    hs: HomeServer
+    site: Site
+    auth_user_id: Optional[str]
 
     @overload
     def create_room_as(
@@ -145,7 +146,7 @@ class RestHelper:
 
     def invite(
         self,
-        room: Optional[str] = None,
+        room: str,
         src: Optional[str] = None,
         targ: Optional[str] = None,
         expect_code: int = HTTPStatus.OK,
@@ -216,7 +217,7 @@ class RestHelper:
 
     def leave(
         self,
-        room: Optional[str] = None,
+        room: str,
         user: Optional[str] = None,
         expect_code: int = HTTPStatus.OK,
         tok: Optional[str] = None,
@@ -230,14 +231,22 @@ class RestHelper:
             expect_code=expect_code,
         )
 
-    def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None:
+    def ban(
+        self,
+        room: str,
+        src: str,
+        targ: str,
+        expect_code: int = HTTPStatus.OK,
+        tok: Optional[str] = None,
+    ) -> None:
         """A convenience helper: `change_membership` with `membership` preset to "ban"."""
         self.change_membership(
             room=room,
             src=src,
             targ=targ,
+            tok=tok,
             membership=Membership.BAN,
-            **kwargs,
+            expect_code=expect_code,
         )
 
     def change_membership(
@@ -378,7 +387,7 @@ class RestHelper:
         room_id: str,
         event_type: str,
         body: Optional[Dict[str, Any]],
-        tok: str,
+        tok: Optional[str],
         expect_code: int = HTTPStatus.OK,
         state_key: str = "",
         method: str = "GET",
@@ -458,7 +467,7 @@ class RestHelper:
         room_id: str,
         event_type: str,
         body: Dict[str, Any],
-        tok: str,
+        tok: Optional[str],
         expect_code: int = HTTPStatus.OK,
         state_key: str = "",
     ) -> JsonDict:
@@ -658,7 +667,12 @@ class RestHelper:
             (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
         ]
 
-        async def mock_req(method: str, uri: str, data=None, headers=None):
+        async def mock_req(
+            method: str,
+            uri: str,
+            data: Optional[dict] = None,
+            headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        ):
             (expected_uri, resp_obj) = expected_requests.pop(0)
             assert uri == expected_uri
             resp = FakeResponse(