diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 605b952316..7eba69642a 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -453,7 +453,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# stick the flows results in a dict by type
- flow_results = {} # type: Dict[str, Any]
+ flow_results: Dict[str, Any] = {}
for f in channel.json_body["flows"]:
flow_type = f["type"]
self.assertNotIn(
@@ -501,7 +501,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
p.close()
# there should be a link for each href
- returned_idps = [] # type: List[str]
+ returned_idps: List[str] = []
for link in p.links:
path, query = link.split("?", 1)
self.assertEqual(path, "pick_idp")
@@ -582,7 +582,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# ... and should have set a cookie including the redirect url
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
assert cookie_headers
- cookies = {} # type: Dict[str, str]
+ cookies: Dict[str, str] = {}
for h in cookie_headers:
key, value = h.split(";")[0].split("=", maxsplit=1)
cookies[key] = value
@@ -874,9 +874,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result = jwt.encode(
- payload, secret, self.jwt_algorithm
- ) # type: Union[str, bytes]
+ result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm)
if isinstance(result, bytes):
return result.decode("ascii")
return result
@@ -1084,7 +1082,7 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
- result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
+ result: Union[bytes, str] = jwt.encode(payload, secret, "RS256")
if isinstance(result, bytes):
return result.decode("ascii")
return result
@@ -1272,7 +1270,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
# ... with a username_mapping_session cookie
- cookies = {} # type: Dict[str,str]
+ cookies: Dict[str, str] = {}
channel.extract_cookies(cookies)
self.assertIn("username_mapping_session", cookies)
session_id = cookies["username_mapping_session"]
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e94566ffd7..3df070c936 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1206,7 +1206,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/join".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/join",
content={"reason": reason},
access_token=self.second_tok,
)
@@ -1220,7 +1220,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/leave",
content={"reason": reason},
access_token=self.second_tok,
)
@@ -1234,7 +1234,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/kick".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/kick",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.second_tok,
)
@@ -1248,7 +1248,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/ban".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/ban",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
@@ -1260,7 +1260,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/unban".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/unban",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
@@ -1272,7 +1272,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/invite".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/invite",
content={"reason": reason, "user_id": self.second_user_id},
access_token=self.creator_tok,
)
@@ -1291,7 +1291,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
reason = "hello"
channel = self.make_request(
"POST",
- "/_matrix/client/r0/rooms/{}/leave".format(self.room_id),
+ f"/_matrix/client/r0/rooms/{self.room_id}/leave",
content={"reason": reason},
access_token=self.second_tok,
)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 69798e95c3..fc2d35596e 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -19,7 +19,7 @@ import json
import re
import time
import urllib.parse
-from typing import Any, Dict, Mapping, MutableMapping, Optional
+from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
from unittest.mock import patch
import attr
@@ -53,6 +53,9 @@ class RestHelper:
tok: str = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
) -> str:
"""
Create a room.
@@ -87,6 +90,7 @@ class RestHelper:
"POST",
path,
json.dumps(content).encode("utf8"),
+ custom_headers=custom_headers,
)
assert channel.result["code"] == b"%d" % expect_code, channel.result
@@ -175,14 +179,30 @@ class RestHelper:
self.auth_user_id = temp_id
- def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
+ def send(
+ self,
+ room_id,
+ body=None,
+ txn_id=None,
+ tok=None,
+ expect_code=200,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ ):
if body is None:
body = "body_text_here"
content = {"msgtype": "m.text", "body": body}
return self.send_event(
- room_id, "m.room.message", content, txn_id, tok, expect_code
+ room_id,
+ "m.room.message",
+ content,
+ txn_id,
+ tok,
+ expect_code,
+ custom_headers=custom_headers,
)
def send_event(
@@ -193,6 +213,9 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -207,6 +230,7 @@ class RestHelper:
"PUT",
path,
json.dumps(content or {}).encode("utf8"),
+ custom_headers=custom_headers,
)
assert (
|