summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v1/utils.py10
-rw-r--r--tests/server.py11
-rw-r--r--tests/storage/test_client_ips.py13
3 files changed, 23 insertions, 11 deletions
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index afaf9f7b85..1b2d0497a6 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -296,10 +296,12 @@ class RestHelper:
         image_length = len(image_data)
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
         request, channel = make_request(
-            self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
-        )
-        request.requestHeaders.addRawHeader(
-            b"Content-Length", str(image_length).encode("UTF-8")
+            self.hs.get_reactor(),
+            "POST",
+            path,
+            content=image_data,
+            access_token=tok,
+            custom_headers=[(b"Content-Length", str(image_length))],
         )
         request.render(resource)
         self.hs.get_reactor().pump([100])
diff --git a/tests/server.py b/tests/server.py
index 3dd2cfc072..ef03109a6c 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable
+from typing import Callable, Iterable, Optional, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -139,6 +139,9 @@ def make_request(
     shorthand=True,
     federation_auth_origin=None,
     content_is_form=False,
+    custom_headers: Optional[
+        Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+    ] = None,
 ):
     """
     Make a web request using the given method and path, feed it the
@@ -157,6 +160,8 @@ def make_request(
         content_is_form: Whether the content is URL encoded form data. Adds the
             'Content-Type': 'application/x-www-form-urlencoded' header.
 
+        custom_headers: (name, value) pairs to add as request headers
+
     Returns:
         Tuple[synapse.http.site.SynapseRequest, channel]
     """
@@ -211,6 +216,10 @@ def make_request(
             # Assume the body is JSON
             req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
 
+    if custom_headers:
+        for k, v in custom_headers:
+            req.requestHeaders.addRawHeader(k, v)
+
     req.requestReceived(method, path, b"1.1")
 
     return req, channel
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index e96ca1c8ca..efca43ec78 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
 from synapse.rest.client.v1 import login
 
 from tests import unittest
+from tests.server import make_request
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
@@ -408,17 +409,17 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
         # Advance to a known time
         self.reactor.advance(123456 - self.reactor.seconds())
 
-        request, channel = self.make_request(
+        headers1 = {b"User-Agent": b"Mozzila pizza"}
+        headers1.update(headers)
+
+        request, channel = make_request(
+            self.reactor,
             "GET",
             "/_matrix/client/r0/admin/users/" + self.user_id,
             access_token=access_token,
+            custom_headers=headers1.items(),
             **make_request_args,
         )
-        request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
-
-        # Add the optional headers
-        for h, v in headers.items():
-            request.requestHeaders.addRawHeader(h, v)
         self.render(request)
 
         # Advance so the save loop occurs