diff --git a/tests/unittest.py b/tests/unittest.py
index 74db7c08f1..3eec9c4d5b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -140,7 +140,7 @@ class TestCase(unittest.TestCase):
try:
self.assertEquals(attrs[key], getattr(obj, key))
except AssertionError as e:
- raise (type(e))("Assert error for '.{}':".format(key)) from e
+ raise (type(e))(f"Assert error for '.{key}':") from e
def assert_dict(self, required, actual):
"""Does a partial assert of a dict.
@@ -520,7 +520,7 @@ class HomeserverTestCase(TestCase):
if not isinstance(deferred, Deferred):
return d
- results = [] # type: list
+ results: list = []
deferred.addBoth(results.append)
self.pump(by=by)
@@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase):
user_id = channel.json_body["user_id"]
return user_id
- def login(self, username, password, device_id=None):
+ def login(
+ self,
+ username,
+ password,
+ device_id=None,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
+ ):
"""
Log in a user, and get an access token. Requires the Login API be
registered.
@@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase):
body["device_id"] = device_id
channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
+ "POST",
+ "/_matrix/client/r0/login",
+ json.dumps(body).encode("utf8"),
+ custom_headers=custom_headers,
)
self.assertEqual(channel.code, 200, channel.result)
|