diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 46cd5f70a8..28663826fc 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -41,6 +41,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
+from synapse.server import HomeServer
from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request
@@ -48,15 +49,15 @@ from tests.test_utils import FakeResponse
from tests.test_utils.html_parsers import TestHtmlParser
-@attr.s
+@attr.s(auto_attribs=True)
class RestHelper:
"""Contains extra helper functions to quickly and clearly perform a given
REST action, which isn't the focus of the test.
"""
- hs = attr.ib()
- site = attr.ib(type=Site)
- auth_user_id = attr.ib()
+ hs: HomeServer
+ site: Site
+ auth_user_id: Optional[str]
@overload
def create_room_as(
@@ -145,7 +146,7 @@ class RestHelper:
def invite(
self,
- room: Optional[str] = None,
+ room: str,
src: Optional[str] = None,
targ: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
@@ -216,7 +217,7 @@ class RestHelper:
def leave(
self,
- room: Optional[str] = None,
+ room: str,
user: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
tok: Optional[str] = None,
@@ -230,14 +231,22 @@ class RestHelper:
expect_code=expect_code,
)
- def ban(self, room: str, src: str, targ: str, **kwargs: object) -> None:
+ def ban(
+ self,
+ room: str,
+ src: str,
+ targ: str,
+ expect_code: int = HTTPStatus.OK,
+ tok: Optional[str] = None,
+ ) -> None:
"""A convenience helper: `change_membership` with `membership` preset to "ban"."""
self.change_membership(
room=room,
src=src,
targ=targ,
+ tok=tok,
membership=Membership.BAN,
- **kwargs,
+ expect_code=expect_code,
)
def change_membership(
@@ -378,7 +387,7 @@ class RestHelper:
room_id: str,
event_type: str,
body: Optional[Dict[str, Any]],
- tok: str,
+ tok: Optional[str],
expect_code: int = HTTPStatus.OK,
state_key: str = "",
method: str = "GET",
@@ -458,7 +467,7 @@ class RestHelper:
room_id: str,
event_type: str,
body: Dict[str, Any],
- tok: str,
+ tok: Optional[str],
expect_code: int = HTTPStatus.OK,
state_key: str = "",
) -> JsonDict:
@@ -658,7 +667,12 @@ class RestHelper:
(TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
]
- async def mock_req(method: str, uri: str, data=None, headers=None):
+ async def mock_req(
+ method: str,
+ uri: str,
+ data: Optional[dict] = None,
+ headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ ):
(expected_uri, resp_obj) = expected_requests.pop(0)
assert uri == expected_uri
resp = FakeResponse(
|