summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index 74db7c08f1..3eec9c4d5b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -140,7 +140,7 @@ class TestCase(unittest.TestCase):
             try:
                 self.assertEquals(attrs[key], getattr(obj, key))
             except AssertionError as e:
-                raise (type(e))("Assert error for '.{}':".format(key)) from e
+                raise (type(e))(f"Assert error for '.{key}':") from e
 
     def assert_dict(self, required, actual):
         """Does a partial assert of a dict.
@@ -520,7 +520,7 @@ class HomeserverTestCase(TestCase):
         if not isinstance(deferred, Deferred):
             return d
 
-        results = []  # type: list
+        results: list = []
         deferred.addBoth(results.append)
 
         self.pump(by=by)
@@ -594,7 +594,15 @@ class HomeserverTestCase(TestCase):
         user_id = channel.json_body["user_id"]
         return user_id
 
-    def login(self, username, password, device_id=None):
+    def login(
+        self,
+        username,
+        password,
+        device_id=None,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
+    ):
         """
         Log in a user, and get an access token. Requires the Login API be
         registered.
@@ -605,7 +613,10 @@ class HomeserverTestCase(TestCase):
             body["device_id"] = device_id
 
         channel = self.make_request(
-            "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
+            "POST",
+            "/_matrix/client/r0/login",
+            json.dumps(body).encode("utf8"),
+            custom_headers=custom_headers,
         )
         self.assertEqual(channel.code, 200, channel.result)