diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 3fb77fd9dd..2b1e44381b 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -969,9 +969,8 @@ class CASTestCase(unittest.HomeserverTestCase):
# Test that the response is HTML.
self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = ""
- for header in channel.result.get("headers", []):
- if header[0] == b"Content-Type":
- content_type_header_value = header[1].decode("utf8")
+ for header in channel.headers.getRawHeaders("Content-Type", []):
+ content_type_header_value = header
self.assertTrue(content_type_header_value.startswith("text/html"))
diff --git a/tests/server.py b/tests/server.py
index f1cd0f76be..38ca095073 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -198,17 +198,35 @@ class FakeChannel:
def headers(self) -> Headers:
if not self.result:
raise Exception("No result yet.")
- h = Headers()
- for i in self.result["headers"]:
- h.addRawHeader(*i)
+
+ h = self.result["headers"]
+ assert isinstance(h, Headers)
return h
def writeHeaders(
- self, version: bytes, code: bytes, reason: bytes, headers: Headers
+ self,
+ version: bytes,
+ code: bytes,
+ reason: bytes,
+ headers: Union[Headers, List[Tuple[bytes, bytes]]],
) -> None:
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
+
+ if isinstance(headers, list):
+ # Support prior to Twisted 24.7.0rc1
+ new_headers = Headers()
+ for k, v in headers:
+ assert isinstance(k, bytes), f"key is not of type bytes: {k!r}"
+ assert isinstance(v, bytes), f"value is not of type bytes: {v!r}"
+ new_headers.addRawHeader(k, v)
+ headers = new_headers
+
+ assert isinstance(
+ headers, Headers
+ ), f"headers are of the wrong type: {headers!r}"
+
self.result["headers"] = headers
def write(self, data: bytes) -> None:
diff --git a/tests/test_server.py b/tests/test_server.py
index 0910ea5f28..9ff2589497 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -392,8 +392,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
)
self.assertEqual(channel.code, 301)
- headers = channel.result["headers"]
- location_headers = [v for k, v in headers if k == b"Location"]
+ location_headers = channel.headers.getRawHeaders(b"Location", [])
self.assertEqual(location_headers, [b"/look/an/eagle"])
def test_redirect_exception_with_cookie(self) -> None:
@@ -415,10 +414,10 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
)
self.assertEqual(channel.code, 304)
- headers = channel.result["headers"]
- location_headers = [v for k, v in headers if k == b"Location"]
+ headers = channel.headers
+ location_headers = headers.getRawHeaders(b"Location", [])
self.assertEqual(location_headers, [b"/no/over/there"])
- cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+ cookies_headers = headers.getRawHeaders(b"Set-Cookie", [])
self.assertEqual(cookies_headers, [b"session=yespls"])
def test_head_request(self) -> None:
|