diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 7cf782e2d6..1af5e5cee5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -28,11 +28,12 @@ from typing import (
MutableMapping,
Optional,
Tuple,
- Union,
+ overload,
)
from unittest.mock import patch
import attr
+from typing_extensions import Literal
from twisted.web.resource import Resource
from twisted.web.server import Site
@@ -55,6 +56,32 @@ class RestHelper:
site = attr.ib(type=Site)
auth_user_id = attr.ib()
+ @overload
+ def create_room_as(
+ self,
+ room_creator: Optional[str] = ...,
+ is_public: Optional[bool] = ...,
+ room_version: Optional[str] = ...,
+ tok: Optional[str] = ...,
+ expect_code: Literal[200] = ...,
+ extra_content: Optional[Dict] = ...,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+ ) -> str:
+ ...
+
+ @overload
+ def create_room_as(
+ self,
+ room_creator: Optional[str] = ...,
+ is_public: Optional[bool] = ...,
+ room_version: Optional[str] = ...,
+ tok: Optional[str] = ...,
+ expect_code: int = ...,
+ extra_content: Optional[Dict] = ...,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
+ ) -> Optional[str]:
+ ...
+
def create_room_as(
self,
room_creator: Optional[str] = None,
@@ -64,7 +91,7 @@ class RestHelper:
expect_code: int = 200,
extra_content: Optional[Dict] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
- ) -> str:
+ ) -> Optional[str]:
"""
Create a room.
@@ -107,6 +134,8 @@ class RestHelper:
if expect_code == 200:
return channel.json_body["room_id"]
+ else:
+ return None
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
@@ -176,7 +205,7 @@ class RestHelper:
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
- expect_errcode: str = None,
+ expect_errcode: Optional[str] = None,
) -> None:
"""
Send a membership state event into a room.
@@ -260,9 +289,7 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
- custom_headers: Optional[
- Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
- ] = None,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -509,7 +536,7 @@ class RestHelper:
went.
"""
- cookies = {}
+ cookies: Dict[str, str] = {}
# if we're doing a ui auth, hit the ui auth redirect endpoint
if ui_auth_session_id:
@@ -631,7 +658,13 @@ class RestHelper:
# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
- location = channel.headers.getRawHeaders("Location")[0]
+ def get_location(channel: FakeChannel) -> str:
+ location_values = channel.headers.getRawHeaders("Location")
+ # Keep mypy happy by asserting that location_values is nonempty
+ assert location_values
+ return location_values[0]
+
+ location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
self.hs.get_reactor(),
@@ -645,7 +678,7 @@ class RestHelper:
assert channel.code == 302
channel.extract_cookies(cookies)
- return channel.headers.getRawHeaders("Location")[0]
+ return get_location(channel)
def initiate_sso_ui_auth(
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
|