summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/test_login.py5
-rw-r--r--tests/server.py26
-rw-r--r--tests/test_server.py9
3 files changed, 28 insertions, 12 deletions
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 3fb77fd9dd..2b1e44381b 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -969,9 +969,8 @@ class CASTestCase(unittest.HomeserverTestCase):
         # Test that the response is HTML.
         self.assertEqual(channel.code, 200, channel.result)
         content_type_header_value = ""
-        for header in channel.result.get("headers", []):
-            if header[0] == b"Content-Type":
-                content_type_header_value = header[1].decode("utf8")
+        for header in channel.headers.getRawHeaders("Content-Type", []):
+            content_type_header_value = header
 
         self.assertTrue(content_type_header_value.startswith("text/html"))
 
diff --git a/tests/server.py b/tests/server.py
index 85602e6953..3e377585ce 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -198,17 +198,35 @@ class FakeChannel:
     def headers(self) -> Headers:
         if not self.result:
             raise Exception("No result yet.")
-        h = Headers()
-        for i in self.result["headers"]:
-            h.addRawHeader(*i)
+
+        h = self.result["headers"]
+        assert isinstance(h, Headers)
         return h
 
     def writeHeaders(
-        self, version: bytes, code: bytes, reason: bytes, headers: Headers
+        self,
+        version: bytes,
+        code: bytes,
+        reason: bytes,
+        headers: Union[Headers, List[Tuple[bytes, bytes]]],
     ) -> None:
         self.result["version"] = version
         self.result["code"] = code
         self.result["reason"] = reason
+
+        if isinstance(headers, list):
+            # Support prior to Twisted 24.7.0rc1
+            new_headers = Headers()
+            for k, v in headers:
+                assert isinstance(k, bytes), f"key is not of type bytes: {k!r}"
+                assert isinstance(v, bytes), f"value is not of type bytes: {v!r}"
+                new_headers.addRawHeader(k, v)
+            headers = new_headers
+
+        assert isinstance(
+            headers, Headers
+        ), f"headers are of the wrong type: {headers!r}"
+
         self.result["headers"] = headers
 
     def write(self, data: bytes) -> None:
diff --git a/tests/test_server.py b/tests/test_server.py
index 0910ea5f28..9ff2589497 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -392,8 +392,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         )
 
         self.assertEqual(channel.code, 301)
-        headers = channel.result["headers"]
-        location_headers = [v for k, v in headers if k == b"Location"]
+        location_headers = channel.headers.getRawHeaders(b"Location", [])
         self.assertEqual(location_headers, [b"/look/an/eagle"])
 
     def test_redirect_exception_with_cookie(self) -> None:
@@ -415,10 +414,10 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         )
 
         self.assertEqual(channel.code, 304)
-        headers = channel.result["headers"]
-        location_headers = [v for k, v in headers if k == b"Location"]
+        headers = channel.headers
+        location_headers = headers.getRawHeaders(b"Location", [])
         self.assertEqual(location_headers, [b"/no/over/there"])
-        cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+        cookies_headers = headers.getRawHeaders(b"Set-Cookie", [])
         self.assertEqual(cookies_headers, [b"session=yespls"])
 
     def test_head_request(self) -> None: