diff --git a/changelog.d/8757.misc b/changelog.d/8757.misc
new file mode 100644
index 0000000000..54502e9b90
--- /dev/null
+++ b/changelog.d/8757.misc
@@ -0,0 +1 @@
+Refactor test utilities for injecting HTTP requests.
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index 4a301b84e1..0bac7995e8 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -15,6 +15,7 @@
from synapse.app.generic_worker import GenericWorkerServer
+from tests.server import make_request, render
from tests.unittest import HomeserverTestCase
@@ -55,10 +56,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = site.resource.children[b"_matrix"].children[b"client"]
+ resource = site.resource.children[b"_matrix"].children[b"client"]
- request, channel = self.make_request("PUT", "presence/a/status")
- self.render(request)
+ request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+ render(request, resource, self.reactor)
# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
@@ -77,10 +78,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
- self.resource = site.resource.children[b"_matrix"].children[b"client"]
+ resource = site.resource.children[b"_matrix"].children[b"client"]
- request, channel = self.make_request("PUT", "presence/a/status")
- self.render(request)
+ request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
+ render(request, resource, self.reactor)
# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index c2b10d2c70..1292145890 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -20,6 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def
+from tests.server import make_request, render
from tests.unittest import HomeserverTestCase
@@ -66,16 +67,16 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
- self.resource = site.resource.children[b"_matrix"].children[b"federation"]
+ resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/openid/userinfo"
+ request, channel = make_request(
+ self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
- self.render(request)
+ render(request, resource, self.reactor)
self.assertEqual(channel.code, 401)
@@ -115,15 +116,15 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
- self.resource = site.resource.children[b"_matrix"].children[b"federation"]
+ resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise
- request, channel = self.make_request(
- "GET", "/_matrix/federation/v1/openid/userinfo"
+ request, channel = make_request(
+ self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
- self.render(request)
+ render(request, resource, self.reactor)
self.assertEqual(channel.code, 401)
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
index 62d36c2906..e835512a41 100644
--- a/tests/http/test_additional_resource.py
+++ b/tests/http/test_additional_resource.py
@@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json
+from tests.server import FakeSite, make_request, render
from tests.unittest import HomeserverTestCase
@@ -43,20 +44,20 @@ class AdditionalResourceTests(HomeserverTestCase):
def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
- self.resource = AdditionalResource(self.hs, handler)
+ resource = AdditionalResource(self.hs, handler)
- request, channel = self.make_request("GET", "/")
- self.render(request)
+ request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+ render(request, resource, self.reactor)
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
- self.resource = AdditionalResource(self.hs, handler)
+ resource = AdditionalResource(self.hs, handler)
- request, channel = self.make_request("GET", "/")
- self.render(request)
+ request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+ render(request, resource, self.reactor)
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 86c03fd89c..90172bd377 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -20,7 +20,7 @@ from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel
+from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__)
@@ -46,8 +46,11 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Test that registration works when using a single client reader worker.
"""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
- request_1, channel_1 = self.make_request(
+ request_1, channel_1 = make_request(
+ self.reactor,
+ site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
@@ -59,8 +62,12 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ request_2, channel_2 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)
@@ -74,7 +81,10 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
- request_1, channel_1 = self.make_request(
+ site_1 = self._hs_to_site[worker_hs_1]
+ request_1, channel_1 = make_request(
+ self.reactor,
+ site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
@@ -86,8 +96,13 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = self.make_request(
- "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ site_2 = self._hs_to_site[worker_hs_2]
+ request_2, channel_2 = make_request(
+ self.reactor,
+ site_2,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 77c261dbf7..48b574ccbe 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -28,7 +28,7 @@ from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.server import FakeChannel, FakeTransport
+from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__)
@@ -67,14 +67,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""
-
- request, channel = self.make_request(
+ resource = hs.get_media_repository_resource().children[b"download"]
+ _, channel = make_request(
+ self.reactor,
+ FakeSite(resource),
"GET",
"/{}/{}".format(target, media_id),
shorthand=False,
access_token=self.access_token,
+ await_result=False,
)
- request.render(hs.get_media_repository_resource().children[b"download"])
self.pump()
clients = self.reactor.tcpClients
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 82cf033d4e..2820dd622f 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -22,6 +22,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import make_request
from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
@@ -148,6 +149,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"},
)
+ sync_hs_site = self._hs_to_site[sync_hs]
# Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test"
@@ -178,7 +180,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Do an initial sync so that we're up to date.
- request, channel = self.make_request("GET", "/sync", access_token=access_token)
+ request, channel = make_request(
+ self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
+ )
self.render_on_worker(sync_hs, request)
next_batch = channel.json_body["next_batch"]
@@ -203,8 +207,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
- request, channel = self.make_request(
- "GET", "/sync?since={}".format(next_batch), access_token=access_token
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(next_batch),
+ access_token=access_token,
)
self.render_on_worker(sync_hs, request)
@@ -230,7 +238,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/sync?since={}".format(vector_clock_token),
access_token=access_token,
@@ -254,8 +264,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
- request, channel = self.make_request(
- "GET", "/sync?since={}".format(next_batch), access_token=access_token
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
+ "GET",
+ "/sync?since={}".format(next_batch),
+ access_token=access_token,
)
self.render_on_worker(sync_hs, request)
@@ -269,7 +283,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token
@@ -281,7 +297,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token
@@ -295,7 +313,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Paginating forwards should give the same results
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1
@@ -305,7 +325,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2,
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 6804f9337f..961a5732b3 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -30,6 +30,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
+from tests.server import FakeSite, make_request
class VersionTestCase(unittest.HomeserverTestCase):
@@ -222,11 +223,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
- request, channel = self.make_request(
- "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
+ "GET",
+ server_and_media_id,
+ shorthand=False,
+ access_token=admin_user_tok,
)
- request.render(self.download_resource)
- self.pump(1.0)
# Should be quarantined
self.assertEqual(
@@ -287,14 +291,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
"GET",
server_name_and_media_id,
shorthand=False,
access_token=non_admin_user_tok,
)
- request.render(self.download_resource)
- self.pump(1.0)
# Should be successful
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
@@ -462,14 +466,14 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
"GET",
server_and_media_id_2,
shorthand=False,
access_token=non_admin_user_tok,
)
- request.render(self.download_resource)
- self.pump(1.0)
# Shouldn't be quarantined
self.assertEqual(
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 721fa1ed51..64b7aa53ee 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, profile, room
from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest
+from tests.server import FakeSite, make_request
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@@ -124,14 +125,14 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name)
# Attempt to access media
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=self.admin_user_tok,
)
- request.render(download_resource)
- self.pump(1.0)
# Should be successful
self.assertEqual(
@@ -161,14 +162,14 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
)
# Attempt to access media
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=self.admin_user_tok,
)
- request.render(download_resource)
- self.pump(1.0)
self.assertEqual(
404,
channel.code,
@@ -535,14 +536,14 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id)
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=self.admin_user_tok,
)
- request.render(download_resource)
- self.pump(1.0)
if expect_success:
self.assertEqual(
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 6803b372ac..2931859f25 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.consent import consent_resource
from tests import unittest
-from tests.server import render
+from tests.server import FakeSite, make_request, render
class ConsentResourceTestCase(unittest.HomeserverTestCase):
@@ -61,7 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def test_render_public_consent(self):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
- request, channel = self.make_request("GET", "/consent?v=1", shorthand=False)
+ request, channel = make_request(
+ self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
+ )
render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
@@ -81,8 +83,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ "&u=user"
)
- request, channel = self.make_request(
- "GET", consent_uri, access_token=access_token, shorthand=False
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(resource),
+ "GET",
+ consent_uri,
+ access_token=access_token,
+ shorthand=False,
)
render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
@@ -92,7 +99,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(resource),
"POST",
consent_uri + "&v=" + version,
access_token=access_token,
@@ -103,8 +112,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# Fetch the consent page, to get the consent version -- it should have
# changed
- request, channel = self.make_request(
- "GET", consent_uri, access_token=access_token, shorthand=False
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(resource),
+ "GET",
+ consent_uri,
+ access_token=access_token,
+ shorthand=False,
)
render(request, resource, self.reactor)
self.assertEqual(channel.code, 200)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 900852f85b..040a92d6f0 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -310,7 +310,7 @@ class RestHelper:
"""
image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
- request, channel = make_request(
+ _, channel = make_request(
self.hs.get_reactor(),
FakeSite(resource),
"POST",
@@ -319,8 +319,6 @@ class RestHelper:
access_token=tok,
custom_headers=[(b"Content-Length", str(image_length))],
)
- request.render(resource)
- self.hs.get_reactor().pump([100])
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 66ac4dbe85..b871200909 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest
+from tests.server import FakeSite, make_request
from tests.unittest import override_config
@@ -255,9 +256,14 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
# Load the password reset confirmation page
- request, channel = self.make_request("GET", path, shorthand=False)
- request.render(self.submit_token_resource)
- self.pump()
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.submit_token_resource),
+ "GET",
+ path,
+ shorthand=False,
+ )
+
self.assertEquals(200, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
@@ -271,15 +277,15 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
form_args.append(arg)
# Confirm the password reset
- request, channel = self.make_request(
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.submit_token_resource),
"POST",
path,
content=urlencode(form_args).encode("utf8"),
shorthand=False,
content_is_form=True,
)
- request.render(self.submit_token_resource)
- self.pump()
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 5f897d49cf..2a3b2a8f27 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
+from tests.server import FakeSite, make_request
class MediaStorageTests(unittest.HomeserverTestCase):
@@ -227,8 +228,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _req(self, content_disposition):
- request, channel = self.make_request("GET", self.media_id, shorthand=False)
- request.render(self.download_resource)
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.download_resource),
+ "GET",
+ self.media_id,
+ shorthand=False,
+ await_result=False,
+ )
self.pump()
# We've made one fetch, to example.com, using the media URL, and asking
@@ -317,10 +324,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
- request, channel = self.make_request(
- "GET", self.media_id + params, shorthand=False
+ request, channel = make_request(
+ self.reactor,
+ FakeSite(self.thumbnail_resource),
+ "GET",
+ self.media_id + params,
+ shorthand=False,
+ await_result=False,
)
- request.render(self.thumbnail_resource)
self.pump()
headers = {
@@ -348,7 +359,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_NOT_FOUND",
- "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
- % method,
+ "error": "Not found [b'example.com', b'12345']",
},
)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index c00a7b9114..ccdc8c2ecf 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -133,13 +133,18 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver()
+ def create_test_resource(self):
+ return self.hs.get_media_repository_resource()
+
def test_cache_returns_correct_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -160,10 +165,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check the cache returns the correct response
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET", "preview_url?url=http://matrix.org", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# Check the cache response has the same content
self.assertEqual(channel.code, 200)
@@ -178,10 +181,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check the database cache returns the correct response
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET", "preview_url?url=http://matrix.org", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# Check the cache response has the same content
self.assertEqual(channel.code, 200)
@@ -201,9 +202,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -234,9 +237,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -267,9 +272,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
request, channel = self.make_request(
- "GET", "url_preview?url=http://matrix.org", shorthand=False
+ "GET",
+ "preview_url?url=http://matrix.org",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -298,9 +305,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET",
+ "preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -326,10 +335,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
@@ -349,10 +356,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 502)
self.assertEqual(
@@ -368,10 +373,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
Blacklisted IP addresses, accessed directly, are not spidered.
"""
request, channel = self.make_request(
- "GET", "url_preview?url=http://192.168.1.1", shorthand=False
+ "GET", "preview_url?url=http://192.168.1.1", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
@@ -389,10 +392,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
Blacklisted IP ranges, accessed directly, are not spidered.
"""
request, channel = self.make_request(
- "GET", "url_preview?url=http://1.1.1.2", shorthand=False
+ "GET", "preview_url?url=http://1.1.1.2", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 403)
self.assertEqual(
@@ -411,9 +412,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET",
+ "preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -446,10 +449,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 502)
self.assertEqual(
channel.json_body,
@@ -468,10 +469,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
# No requests made.
self.assertEqual(len(self.reactor.tcpClients), 0)
@@ -491,10 +490,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 502)
self.assertEqual(
@@ -510,10 +507,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
OPTIONS returns the OPTIONS.
"""
request, channel = self.make_request(
- "OPTIONS", "url_preview?url=http://example.com", shorthand=False
+ "OPTIONS", "preview_url?url=http://example.com", shorthand=False
)
- request.render(self.preview_url)
- self.pump()
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
@@ -525,9 +520,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Build and make a request to the server
request, channel = self.make_request(
- "GET", "url_preview?url=http://example.com", shorthand=False
+ "GET",
+ "preview_url?url=http://example.com",
+ shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
# Extract Synapse's tcp client
@@ -598,10 +595,10 @@ class URLPreviewTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET",
- "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
@@ -663,10 +660,10 @@ class URLPreviewTests(unittest.HomeserverTestCase):
request, channel = self.make_request(
"GET",
- "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+ "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False,
+ await_result=False,
)
- request.render(self.preview_url)
self.pump()
client = self.reactor.tcpClients[0][2].buildProtocol(None)
diff --git a/tests/server.py b/tests/server.py
index d26a1dc441..de7cb1d8b3 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -171,16 +171,18 @@ def make_request(
shorthand=True,
federation_auth_origin=None,
content_is_form=False,
+ await_result: bool = True,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
):
"""
- Make a web request using the given method and path, feed it the
- content, and return the Request and the Channel underneath.
+ Make a web request using the given method, path and content, and render it
+
+ Returns the Request and the Channel underneath.
Args:
- site: The twisted Site to associate with the Channel
+ site: The twisted Site to use to render the request
method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
@@ -196,6 +198,10 @@ def make_request(
custom_headers: (name, value) pairs to add as request headers
+ await_result: whether to wait for the request to complete rendering. If true,
+ will pump the reactor until the the renderer tells the channel the request
+ is finished.
+
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
"""
@@ -217,17 +223,17 @@ def make_request(
if not path.startswith(b"/"):
path = b"/" + path
+ if isinstance(content, dict):
+ content = json.dumps(content).encode("utf8")
if isinstance(content, str):
content = content.encode("utf8")
channel = FakeChannel(site, reactor)
req = request(channel)
- req.process = lambda: b""
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(SEEK_END)
- req.postpath = list(map(unquote, path[1:].split(b"/")))
if access_token:
req.requestHeaders.addRawHeader(
@@ -255,12 +261,14 @@ def make_request(
req.requestReceived(method, path, b"1.1")
+ if await_result:
+ channel.await_result()
+
return req, channel
def render(request, resource, clock):
- request.render(resource)
- request._channel.await_result()
+ pass
@implementer(IReactorPluggableNameResolver)
diff --git a/tests/unittest.py b/tests/unittest.py
index f0a421e605..9c7eca3b6e 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -377,6 +377,7 @@ class HomeserverTestCase(TestCase):
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
+ await_result: bool = True,
) -> Tuple[SynapseRequest, FakeChannel]:
...
@@ -391,6 +392,7 @@ class HomeserverTestCase(TestCase):
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
+ await_result: bool = True,
) -> Tuple[T, FakeChannel]:
...
@@ -404,6 +406,7 @@ class HomeserverTestCase(TestCase):
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
+ await_result: bool = True,
) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -422,12 +425,13 @@ class HomeserverTestCase(TestCase):
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
+ await_result: whether to wait for the request to complete rendering. If
+ true (the default), will pump the test reactor until the the renderer
+ tells the channel the request is finished.
+
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
"""
- if isinstance(content, dict):
- content = json.dumps(content).encode("utf8")
-
return make_request(
self.reactor,
self.site,
@@ -439,6 +443,7 @@ class HomeserverTestCase(TestCase):
shorthand,
federation_auth_origin,
content_is_form,
+ await_result,
)
def render(self, request):
|