summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/replication/_base.py4
-rw-r--r--tests/rest/admin/test_admin.py2
-rw-r--r--tests/rest/client/v1/utils.py4
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py2
-rw-r--r--tests/rest/test_health.py6
-rw-r--r--tests/rest/test_well_known.py6
-rw-r--r--tests/server.py11
-rw-r--r--tests/storage/test_client_ips.py13
-rw-r--r--tests/unittest.py18
9 files changed, 33 insertions, 33 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 5c633ac6df..bc56b13dcd 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -240,8 +240,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             lambda: self._handle_http_replication_attempt(self.hs, 8765),
         )
 
-    def create_test_json_resource(self):
-        """Overrides `HomeserverTestCase.create_test_json_resource`.
+    def create_test_resource(self):
+        """Overrides `HomeserverTestCase.create_test_resource`.
         """
         # We override this so that it automatically registers all the HTTP
         # replication servlets, without having to explicitly do that in all
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 64b6016729..9e4b0bca53 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -36,7 +36,7 @@ from tests.server import FakeSite, make_request
 class VersionTestCase(unittest.HomeserverTestCase):
     url = "/_synapse/admin/v1/server_version"
 
-    def create_test_json_resource(self):
+    def create_test_resource(self):
         resource = JsonResource(self.hs)
         VersionServlet(self.hs).register(resource)
         return resource
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 60e4b9b846..900852f85b 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -317,9 +317,7 @@ class RestHelper:
             path,
             content=image_data,
             access_token=tok,
-        )
-        request.requestHeaders.addRawHeader(
-            b"Content-Length", str(image_length).encode("UTF-8")
+            custom_headers=[(b"Content-Length", str(image_length))],
         )
         request.render(resource)
         self.hs.get_reactor().pump([100])
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6850c666be..6671cbd32d 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -41,7 +41,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
         self.http_client = Mock()
         return self.setup_test_homeserver(http_client=self.http_client)
 
-    def create_test_json_resource(self):
+    def create_test_resource(self):
         return create_resource_tree(
             {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
         )
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
index 2d021f6565..f4d06e2200 100644
--- a/tests/rest/test_health.py
+++ b/tests/rest/test_health.py
@@ -20,11 +20,9 @@ from tests import unittest
 
 
 class HealthCheckTests(unittest.HomeserverTestCase):
-    def setUp(self):
-        super().setUp()
-
+    def create_test_resource(self):
         # replace the JsonResource with a HealthResource.
-        self.resource = HealthResource()
+        return HealthResource()
 
     def test_health(self):
         request, channel = self.make_request("GET", "/health", shorthand=False)
diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py
index dcd65c2a50..a3746e7130 100644
--- a/tests/rest/test_well_known.py
+++ b/tests/rest/test_well_known.py
@@ -20,11 +20,9 @@ from tests import unittest
 
 
 class WellKnownTests(unittest.HomeserverTestCase):
-    def setUp(self):
-        super().setUp()
-
+    def create_test_resource(self):
         # replace the JsonResource with a WellKnownResource
-        self.resource = WellKnownResource(self.hs)
+        return WellKnownResource(self.hs)
 
     def test_well_known(self):
         self.hs.config.public_baseurl = "https://tesths"
diff --git a/tests/server.py b/tests/server.py
index a74fb3fc67..5850eadf3e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable
+from typing import Callable, Iterable, Optional, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -152,6 +152,9 @@ def make_request(
     shorthand=True,
     federation_auth_origin=None,
     content_is_form=False,
+    custom_headers: Optional[
+        Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+    ] = None,
 ):
     """
     Make a web request using the given method and path, feed it the
@@ -172,6 +175,8 @@ def make_request(
         content_is_form: Whether the content is URL encoded form data. Adds the
             'Content-Type': 'application/x-www-form-urlencoded' header.
 
+        custom_headers: (name, value) pairs to add as request headers
+
     Returns:
         Tuple[synapse.http.site.SynapseRequest, channel]
     """
@@ -227,6 +232,10 @@ def make_request(
             # Assume the body is JSON
             req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
 
+    if custom_headers:
+        for k, v in custom_headers:
+            req.requestHeaders.addRawHeader(k, v)
+
     req.requestReceived(method, path, b"1.1")
 
     return req, channel
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index e96ca1c8ca..efca43ec78 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -21,6 +21,7 @@ from synapse.http.site import XForwardedForRequest
 from synapse.rest.client.v1 import login
 
 from tests import unittest
+from tests.server import make_request
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
@@ -408,17 +409,17 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
         # Advance to a known time
         self.reactor.advance(123456 - self.reactor.seconds())
 
-        request, channel = self.make_request(
+        headers1 = {b"User-Agent": b"Mozzila pizza"}
+        headers1.update(headers)
+
+        request, channel = make_request(
+            self.reactor,
             "GET",
             "/_matrix/client/r0/admin/users/" + self.user_id,
             access_token=access_token,
+            custom_headers=headers1.items(),
             **make_request_args,
         )
-        request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")
-
-        # Add the optional headers
-        for h, v in headers.items():
-            request.requestHeaders.addRawHeader(h, v)
         self.render(request)
 
         # Advance so the save loop occurs
diff --git a/tests/unittest.py b/tests/unittest.py
index 3e656b7b12..e39cb8dec9 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest
+from twisted.web.resource import Resource
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.config.homeserver import HomeServerConfig
@@ -239,10 +240,8 @@ class HomeserverTestCase(TestCase):
         if not isinstance(self.hs, HomeServer):
             raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
 
-        # Register the resources
-        self.resource = self.create_test_json_resource()
-
-        # create a site to wrap the resource.
+        # create the root resource, and a site to wrap it.
+        self.resource = self.create_test_resource()
         self.site = SynapseSite(
             logger_name="synapse.access.http.fake",
             site_tag=self.hs.config.server.server_name,
@@ -323,15 +322,12 @@ class HomeserverTestCase(TestCase):
         hs = self.setup_test_homeserver()
         return hs
 
-    def create_test_json_resource(self):
+    def create_test_resource(self) -> Resource:
         """
-        Create a test JsonResource, with the relevant servlets registerd to it
-
-        The default implementation calls each function in `servlets` to do the
-        registration.
+        Create a the root resource for the test server.
 
-        Returns:
-            JsonResource:
+        The default implementation creates a JsonResource and calls each function in
+        `servlets` to register servletes against it
         """
         resource = JsonResource(self.hs)