summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-11-16 18:22:24 +0000
committerGitHub <noreply@github.com>2020-11-16 18:22:24 +0000
commit3dc1871219f845954a4b7d31fc06739831d67d2e (patch)
treea54654d07bc289d36fea381102c26e6c63d3e0b5
parentMove `wait_until_result` into `FakeChannel` (#8758) (diff)
parentfixup test (diff)
downloadsynapse-3dc1871219f845954a4b7d31fc06739831d67d2e.tar.xz
Merge pull request #8757 from matrix-org/rav/pass_site_to_make_request
Pass a Site into `make_request`
-rw-r--r--changelog.d/8757.misc1
-rw-r--r--tests/app/test_frontend_proxy.py13
-rw-r--r--tests/app/test_openid_listener.py17
-rw-r--r--tests/http/test_additional_resource.py13
-rw-r--r--tests/replication/test_client_reader_shard.py29
-rw-r--r--tests/replication/test_multi_media_repo.py10
-rw-r--r--tests/replication/test_sharded_event_persister.py42
-rw-r--r--tests/rest/admin/test_admin.py18
-rw-r--r--tests/rest/admin/test_media.py13
-rw-r--r--tests/rest/client/test_consent.py28
-rw-r--r--tests/rest/client/v1/utils.py36
-rw-r--r--tests/rest/client/v2_alpha/test_account.py14
-rw-r--r--tests/rest/media/v1/test_media_storage.py17
-rw-r--r--tests/server.py18
-rw-r--r--tests/storage/test_client_ips.py1
-rw-r--r--tests/test_server.py40
-rw-r--r--tests/unittest.py6
17 files changed, 228 insertions, 88 deletions
diff --git a/changelog.d/8757.misc b/changelog.d/8757.misc
new file mode 100644
index 0000000000..54502e9b90
--- /dev/null
+++ b/changelog.d/8757.misc
@@ -0,0 +1 @@
+Refactor test utilities for injecting HTTP requests.
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 4a301b84e1..0bac7995e8 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -15,6 +15,7 @@
 
 from synapse.app.generic_worker import GenericWorkerServer
 
+from tests.server import make_request, render
 from tests.unittest import HomeserverTestCase
 
 
@@ -55,10 +56,10 @@ class FrontendProxyTests(HomeserverTestCase):
         # Grab the resource from the site that was told to listen
         self.assertEqual(len(self.reactor.tcpServers), 1)
         site = self.reactor.tcpServers[0][1]
-        self.resource = site.resource.children[b"_matrix"].children[b"client"]
+        resource = site.resource.children[b"_matrix"].children[b"client"]
 
-        request, channel = self.make_request("PUT", "presence/a/status")
-        self.render(request)
+        request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+        render(request, resource, self.reactor)
 
         # 400 + unrecognised, because nothing is registered
         self.assertEqual(channel.code, 400)
@@ -77,10 +78,10 @@ class FrontendProxyTests(HomeserverTestCase):
         # Grab the resource from the site that was told to listen
         self.assertEqual(len(self.reactor.tcpServers), 1)
         site = self.reactor.tcpServers[0][1]
-        self.resource = site.resource.children[b"_matrix"].children[b"client"]
+        resource = site.resource.children[b"_matrix"].children[b"client"]
 
-        request, channel = self.make_request("PUT", "presence/a/status")
-        self.render(request)
+        request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+        render(request, resource, self.reactor)
 
         # 401, because the stub servlet still checks authentication
         self.assertEqual(channel.code, 401)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index c2b10d2c70..1292145890 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -20,6 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer
 from synapse.app.homeserver import SynapseHomeServer
 from synapse.config.server import parse_listener_def
 
+from tests.server import make_request, render
 from tests.unittest import HomeserverTestCase
 
 
@@ -66,16 +67,16 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
         # Grab the resource from the site that was told to listen
         site = self.reactor.tcpServers[0][1]
         try:
-            self.resource = site.resource.children[b"_matrix"].children[b"federation"]
+            resource = site.resource.children[b"_matrix"].children[b"federation"]
         except KeyError:
             if expectation == "no_resource":
                 return
             raise
 
-        request, channel = self.make_request(
-            "GET", "/_matrix/federation/v1/openid/userinfo"
+        request, channel = make_request(
+            self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
         )
-        self.render(request)
+        render(request, resource, self.reactor)
 
         self.assertEqual(channel.code, 401)
 
@@ -115,15 +116,15 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
         # Grab the resource from the site that was told to listen
         site = self.reactor.tcpServers[0][1]
         try:
-            self.resource = site.resource.children[b"_matrix"].children[b"federation"]
+            resource = site.resource.children[b"_matrix"].children[b"federation"]
         except KeyError:
             if expectation == "no_resource":
                 return
             raise
 
-        request, channel = self.make_request(
-            "GET", "/_matrix/federation/v1/openid/userinfo"
+        request, channel = make_request(
+            self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
         )
-        self.render(request)
+        render(request, resource, self.reactor)
 
         self.assertEqual(channel.code, 401)
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 62d36c2906..e835512a41 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -17,6 +17,7 @@
 from synapse.http.additional_resource import AdditionalResource
 from synapse.http.server import respond_with_json
 
+from tests.server import FakeSite, make_request, render
 from tests.unittest import HomeserverTestCase
 
 
@@ -43,20 +44,20 @@ class AdditionalResourceTests(HomeserverTestCase):
 
     def test_async(self):
         handler = _AsyncTestCustomEndpoint({}, None).handle_request
-        self.resource = AdditionalResource(self.hs, handler)
+        resource = AdditionalResource(self.hs, handler)
 
-        request, channel = self.make_request("GET", "/")
-        self.render(request)
+        request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+        render(request, resource, self.reactor)
 
         self.assertEqual(request.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
 
     def test_sync(self):
         handler = _SyncTestCustomEndpoint({}, None).handle_request
-        self.resource = AdditionalResource(self.hs, handler)
+        resource = AdditionalResource(self.hs, handler)
 
-        request, channel = self.make_request("GET", "/")
-        self.render(request)
+        request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+        render(request, resource, self.reactor)
 
         self.assertEqual(request.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 86c03fd89c..90172bd377 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -20,7 +20,7 @@ from synapse.rest.client.v2_alpha import register
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel
+from tests.server import FakeChannel, make_request
 
 logger = logging.getLogger(__name__)
 
@@ -46,8 +46,11 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         """Test that registration works when using a single client reader worker.
         """
         worker_hs = self.make_worker_hs("synapse.app.client_reader")
+        site = self._hs_to_site[worker_hs]
 
-        request_1, channel_1 = self.make_request(
+        request_1, channel_1 = make_request(
+            self.reactor,
+            site,
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
@@ -59,8 +62,12 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         session = channel_1.json_body["session"]
 
         # also complete the dummy auth
-        request_2, channel_2 = self.make_request(
-            "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+        request_2, channel_2 = make_request(
+            self.reactor,
+            site,
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": "m.login.dummy"}},
         )  # type: SynapseRequest, FakeChannel
         self.render_on_worker(worker_hs, request_2)
         self.assertEqual(request_2.code, 200)
@@ -74,7 +81,10 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
         worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
 
-        request_1, channel_1 = self.make_request(
+        site_1 = self._hs_to_site[worker_hs_1]
+        request_1, channel_1 = make_request(
+            self.reactor,
+            site_1,
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
@@ -86,8 +96,13 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         session = channel_1.json_body["session"]
 
         # also complete the dummy auth
-        request_2, channel_2 = self.make_request(
-            "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+        site_2 = self._hs_to_site[worker_hs_2]
+        request_2, channel_2 = make_request(
+            self.reactor,
+            site_2,
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": "m.login.dummy"}},
         )  # type: SynapseRequest, FakeChannel
         self.render_on_worker(worker_hs_2, request_2)
         self.assertEqual(request_2.code, 200)
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 77c261dbf7..a9ac4aeec1 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -28,7 +28,7 @@ from synapse.server import HomeServer
 
 from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.server import FakeChannel, FakeTransport
+from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
 
 logger = logging.getLogger(__name__)
 
@@ -67,14 +67,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
             The channel for the *client* request and the *outbound* request for
             the media which the caller should respond to.
         """
-
-        request, channel = self.make_request(
+        resource = hs.get_media_repository_resource().children[b"download"]
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(resource),
             "GET",
             "/{}/{}".format(target, media_id),
             shorthand=False,
             access_token=self.access_token,
         )
-        request.render(hs.get_media_repository_resource().children[b"download"])
+        request.render(resource)
         self.pump()
 
         clients = self.reactor.tcpClients
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 82cf033d4e..2820dd622f 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -22,6 +22,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import sync
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import make_request
 from tests.utils import USE_POSTGRES_FOR_TESTS
 
 logger = logging.getLogger(__name__)
@@ -148,6 +149,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         sync_hs = self.make_worker_hs(
             "synapse.app.generic_worker", {"worker_name": "sync"},
         )
+        sync_hs_site = self._hs_to_site[sync_hs]
 
         # Specially selected room IDs that get persisted on different workers.
         room_id1 = "!foo:test"
@@ -178,7 +180,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # Do an initial sync so that we're up to date.
-        request, channel = self.make_request("GET", "/sync", access_token=access_token)
+        request, channel = make_request(
+            self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
+        )
         self.render_on_worker(sync_hs, request)
         next_batch = channel.json_body["next_batch"]
 
@@ -203,8 +207,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
         # Check that syncing still gets the new event, despite the gap in the
         # stream IDs.
-        request, channel = self.make_request(
-            "GET", "/sync?since={}".format(next_batch), access_token=access_token
+        request, channel = make_request(
+            self.reactor,
+            sync_hs_site,
+            "GET",
+            "/sync?since={}".format(next_batch),
+            access_token=access_token,
         )
         self.render_on_worker(sync_hs, request)
 
@@ -230,7 +238,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
         first_event_in_room2 = response["event_id"]
 
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            sync_hs_site,
             "GET",
             "/sync?since={}".format(vector_clock_token),
             access_token=access_token,
@@ -254,8 +264,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
         self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
 
-        request, channel = self.make_request(
-            "GET", "/sync?since={}".format(next_batch), access_token=access_token
+        request, channel = make_request(
+            self.reactor,
+            sync_hs_site,
+            "GET",
+            "/sync?since={}".format(next_batch),
+            access_token=access_token,
         )
         self.render_on_worker(sync_hs, request)
 
@@ -269,7 +283,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         # Paginating back in the first room should not produce any results, as
         # no events have happened in it. This tests that we are correctly
         # filtering results based on the vector clock portion.
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            sync_hs_site,
             "GET",
             "/rooms/{}/messages?from={}&to={}&dir=b".format(
                 room_id1, prev_batch1, vector_clock_token
@@ -281,7 +297,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
         # Paginating back on the second room should produce the first event
         # again. This tests that pagination isn't completely broken.
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            sync_hs_site,
             "GET",
             "/rooms/{}/messages?from={}&to={}&dir=b".format(
                 room_id2, prev_batch2, vector_clock_token
@@ -295,7 +313,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # Paginating forwards should give the same results
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            sync_hs_site,
             "GET",
             "/rooms/{}/messages?from={}&to={}&dir=f".format(
                 room_id1, vector_clock_token, prev_batch1
@@ -305,7 +325,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.render_on_worker(sync_hs, request)
         self.assertListEqual([], channel.json_body["chunk"])
 
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            sync_hs_site,
             "GET",
             "/rooms/{}/messages?from={}&to={}&dir=f".format(
                 room_id2, vector_clock_token, prev_batch2,
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 6804f9337f..9e4b0bca53 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -30,6 +30,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import groups
 
 from tests import unittest
+from tests.server import FakeSite, make_request
 
 
 class VersionTestCase(unittest.HomeserverTestCase):
@@ -222,8 +223,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
     def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
         """Ensure a piece of media is quarantined when trying to access it."""
-        request, channel = self.make_request(
-            "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(self.download_resource),
+            "GET",
+            server_and_media_id,
+            shorthand=False,
+            access_token=admin_user_tok,
         )
         request.render(self.download_resource)
         self.pump(1.0)
@@ -287,7 +293,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         server_name, media_id = server_name_and_media_id.split("/")
 
         # Attempt to access the media
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(self.download_resource),
             "GET",
             server_name_and_media_id,
             shorthand=False,
@@ -462,7 +470,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
 
         # Attempt to access each piece of media
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(self.download_resource),
             "GET",
             server_and_media_id_2,
             shorthand=False,
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 721fa1ed51..36e07f1b36 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, profile, room
 from synapse.rest.media.v1.filepath import MediaFilePaths
 
 from tests import unittest
+from tests.server import FakeSite, make_request
 
 
 class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@@ -124,7 +125,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         self.assertEqual(server_name, self.server_name)
 
         # Attempt to access media
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(download_resource),
             "GET",
             server_and_media_id,
             shorthand=False,
@@ -161,7 +164,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         )
 
         # Attempt to access media
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(download_resource),
             "GET",
             server_and_media_id,
             shorthand=False,
@@ -535,7 +540,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
         media_id = server_and_media_id.split("/")[1]
         local_path = self.filepaths.local_media_filepath(media_id)
 
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(download_resource),
             "GET",
             server_and_media_id,
             shorthand=False,
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 6803b372ac..2931859f25 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.rest.consent import consent_resource
 
 from tests import unittest
-from tests.server import render
+from tests.server import FakeSite, make_request, render
 
 
 class ConsentResourceTestCase(unittest.HomeserverTestCase):
@@ -61,7 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
     def test_render_public_consent(self):
         """You can observe the terms form without specifying a user"""
         resource = consent_resource.ConsentResource(self.hs)
-        request, channel = self.make_request("GET", "/consent?v=1", shorthand=False)
+        request, channel = make_request(
+            self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
+        )
         render(request, resource, self.reactor)
         self.assertEqual(channel.code, 200)
 
@@ -81,8 +83,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
             uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
             + "&u=user"
         )
-        request, channel = self.make_request(
-            "GET", consent_uri, access_token=access_token, shorthand=False
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(resource),
+            "GET",
+            consent_uri,
+            access_token=access_token,
+            shorthand=False,
         )
         render(request, resource, self.reactor)
         self.assertEqual(channel.code, 200)
@@ -92,7 +99,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(consented, "False")
 
         # POST to the consent page, saying we've agreed
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(resource),
             "POST",
             consent_uri + "&v=" + version,
             access_token=access_token,
@@ -103,8 +112,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
 
         # Fetch the consent page, to get the consent version -- it should have
         # changed
-        request, channel = self.make_request(
-            "GET", consent_uri, access_token=access_token, shorthand=False
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(resource),
+            "GET",
+            consent_uri,
+            access_token=access_token,
+            shorthand=False,
         )
         render(request, resource, self.reactor)
         self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 1b2d0497a6..900852f85b 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -23,10 +23,11 @@ from typing import Any, Dict, Optional
 import attr
 
 from twisted.web.resource import Resource
+from twisted.web.server import Site
 
 from synapse.api.constants import Membership
 
-from tests.server import make_request, render
+from tests.server import FakeSite, make_request, render
 
 
 @attr.s
@@ -36,7 +37,7 @@ class RestHelper:
     """
 
     hs = attr.ib()
-    resource = attr.ib()
+    site = attr.ib(type=Site)
     auth_user_id = attr.ib()
 
     def create_room_as(
@@ -52,9 +53,13 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         request, channel = make_request(
-            self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8")
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            path,
+            json.dumps(content).encode("utf8"),
         )
-        render(request, self.resource, self.hs.get_reactor())
+        render(request, self.site.resource, self.hs.get_reactor())
 
         assert channel.result["code"] == b"%d" % expect_code, channel.result
         self.auth_user_id = temp_id
@@ -125,10 +130,14 @@ class RestHelper:
         data.update(extra_data)
 
         request, channel = make_request(
-            self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
+            self.hs.get_reactor(),
+            self.site,
+            "PUT",
+            path,
+            json.dumps(data).encode("utf8"),
         )
 
-        render(request, self.resource, self.hs.get_reactor())
+        render(request, self.site.resource, self.hs.get_reactor())
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
@@ -158,9 +167,13 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         request, channel = make_request(
-            self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8")
+            self.hs.get_reactor(),
+            self.site,
+            "PUT",
+            path,
+            json.dumps(content).encode("utf8"),
         )
-        render(request, self.resource, self.hs.get_reactor())
+        render(request, self.site.resource, self.hs.get_reactor())
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
@@ -210,9 +223,11 @@ class RestHelper:
         if body is not None:
             content = json.dumps(body).encode("utf8")
 
-        request, channel = make_request(self.hs.get_reactor(), method, path, content)
+        request, channel = make_request(
+            self.hs.get_reactor(), self.site, method, path, content
+        )
 
-        render(request, self.resource, self.hs.get_reactor())
+        render(request, self.site.resource, self.hs.get_reactor())
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
@@ -297,6 +312,7 @@ class RestHelper:
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
         request, channel = make_request(
             self.hs.get_reactor(),
+            FakeSite(resource),
             "POST",
             path,
             content=image_data,
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 66ac4dbe85..94a627b0a6 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import account, register
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
 
 from tests import unittest
+from tests.server import FakeSite, make_request
 from tests.unittest import override_config
 
 
@@ -255,9 +256,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         path = link.replace("https://example.com", "")
 
         # Load the password reset confirmation page
-        request, channel = self.make_request("GET", path, shorthand=False)
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(self.submit_token_resource),
+            "GET",
+            path,
+            shorthand=False,
+        )
         request.render(self.submit_token_resource)
         self.pump()
+
         self.assertEquals(200, channel.code, channel.result)
 
         # Now POST to the same endpoint, mimicking the same behaviour as clicking the
@@ -271,7 +279,9 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
                 form_args.append(arg)
 
         # Confirm the password reset
-        request, channel = self.make_request(
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(self.submit_token_resource),
             "POST",
             path,
             content=urlencode(form_args).encode("utf8"),
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 5f897d49cf..0fd31a0096 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 
 from tests import unittest
+from tests.server import FakeSite, make_request
 
 
 class MediaStorageTests(unittest.HomeserverTestCase):
@@ -227,7 +228,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
     def _req(self, content_disposition):
 
-        request, channel = self.make_request("GET", self.media_id, shorthand=False)
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(self.download_resource),
+            "GET",
+            self.media_id,
+            shorthand=False,
+        )
         request.render(self.download_resource)
         self.pump()
 
@@ -317,8 +324,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
     def _test_thumbnail(self, method, expected_body, expected_found):
         params = "?width=32&height=32&method=" + method
-        request, channel = self.make_request(
-            "GET", self.media_id + params, shorthand=False
+        request, channel = make_request(
+            self.reactor,
+            FakeSite(self.thumbnail_resource),
+            "GET",
+            self.media_id + params,
+            shorthand=False,
         )
         request.render(self.thumbnail_resource)
         self.pump()
diff --git a/tests/server.py b/tests/server.py
index 18cb8b2d72..5a1583a3e7 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -21,6 +21,7 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http import unquote
 from twisted.web.http_headers import Headers
+from twisted.web.resource import IResource
 from twisted.web.server import Site
 
 from synapse.http.site import SynapseRequest
@@ -147,9 +148,21 @@ class FakeSite:
     site_tag = "test"
     access_logger = logging.getLogger("synapse.access.http.fake")
 
+    def __init__(self, resource: IResource):
+        """
+
+        Args:
+            resource: the resource to be used for rendering all requests
+        """
+        self._resource = resource
+
+    def getResourceFor(self, request):
+        return self._resource
+
 
 def make_request(
     reactor,
+    site: Site,
     method,
     path,
     content=b"",
@@ -167,6 +180,8 @@ def make_request(
     content, and return the Request and the Channel underneath.
 
     Args:
+        site: The twisted Site to associate with the Channel
+
         method (bytes/unicode): The HTTP request method ("verb").
         path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
         escaped UTF-8 & spaces and such).
@@ -202,10 +217,11 @@ def make_request(
     if not path.startswith(b"/"):
         path = b"/" + path
 
+    if isinstance(content, dict):
+        content = json.dumps(content).encode("utf8")
     if isinstance(content, str):
         content = content.encode("utf8")
 
-    site = FakeSite()
     channel = FakeChannel(site, reactor)
 
     req = request(channel)
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index efca43ec78..583addb5b5 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -414,6 +414,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
 
         request, channel = make_request(
             self.reactor,
+            self.site,
             "GET",
             "/_matrix/client/r0/admin/users/" + self.user_id,
             access_token=access_token,
diff --git a/tests/test_server.py b/tests/test_server.py
index 655c918a15..300d13ac95 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -26,6 +26,7 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import (
+    FakeSite,
     ThreadedMemoryReactorClock,
     make_request,
     render,
@@ -62,7 +63,7 @@ class JsonResourceTests(unittest.TestCase):
         )
 
         request, channel = make_request(
-            self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
         )
         render(request, res, self.reactor)
 
@@ -83,7 +84,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"500")
@@ -108,7 +111,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"500")
@@ -127,7 +132,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"403")
@@ -150,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"400")
@@ -173,7 +182,9 @@ class JsonResourceTests(unittest.TestCase):
         )
 
         # The path was registered as GET, but this is a HEAD request.
-        request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"200")
@@ -196,9 +207,6 @@ class OptionsResourceTests(unittest.TestCase):
 
     def _make_request(self, method, path):
         """Create a request from the method/path and return a channel with the response."""
-        request, channel = make_request(self.reactor, method, path, shorthand=False)
-        request.prepath = []  # This doesn't get set properly by make_request.
-
         # Create a site and query for the resource.
         site = SynapseSite(
             "test",
@@ -207,6 +215,12 @@ class OptionsResourceTests(unittest.TestCase):
             self.resource,
             "1.0",
         )
+
+        request, channel = make_request(
+            self.reactor, site, method, path, shorthand=False
+        )
+        request.prepath = []  # This doesn't get set properly by make_request.
+
         request.site = site
         resource = site.getResourceFor(request)
 
@@ -284,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        request, channel = make_request(self.reactor, b"GET", b"/path")
+        request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"200")
@@ -303,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        request, channel = make_request(self.reactor, b"GET", b"/path")
+        request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"301")
@@ -325,7 +339,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        request, channel = make_request(self.reactor, b"GET", b"/path")
+        request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"304")
@@ -345,7 +359,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        request, channel = make_request(self.reactor, b"HEAD", b"/path")
+        request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"200")
diff --git a/tests/unittest.py b/tests/unittest.py
index c630760e51..e39cb8dec9 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -252,7 +252,7 @@ class HomeserverTestCase(TestCase):
 
         from tests.rest.client.v1.utils import RestHelper
 
-        self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
+        self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
 
         if hasattr(self, "user_id"):
             if self.hijack_auth:
@@ -425,11 +425,9 @@ class HomeserverTestCase(TestCase):
         Returns:
             Tuple[synapse.http.site.SynapseRequest, channel]
         """
-        if isinstance(content, dict):
-            content = json.dumps(content).encode("utf8")
-
         return make_request(
             self.reactor,
+            self.site,
             method,
             path,
             content,