summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2020-11-13 22:39:09 +0000
committerRichard van der Hoff <richard@matrix.org>2020-11-15 23:09:03 +0000
commit9debe657a39a234d574e949ae8faf3f5ed027c09 (patch)
tree898ade91a556963701584ad0724d57aff45476ae /tests
parentpass a Site into RestHelper (diff)
downloadsynapse-9debe657a39a234d574e949ae8faf3f5ed027c09.tar.xz
pass a Site into make_request
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v1/utils.py31
-rw-r--r--tests/server.py16
-rw-r--r--tests/test_server.py40
-rw-r--r--tests/unittest.py1
4 files changed, 68 insertions, 20 deletions
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index dc789fbdaa..60e4b9b846 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -27,7 +27,7 @@ from twisted.web.server import Site
 
 from synapse.api.constants import Membership
 
-from tests.server import make_request, render
+from tests.server import FakeSite, make_request, render
 
 
 @attr.s
@@ -53,7 +53,11 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         request, channel = make_request(
-            self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8")
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            path,
+            json.dumps(content).encode("utf8"),
         )
         render(request, self.site.resource, self.hs.get_reactor())
 
@@ -126,7 +130,11 @@ class RestHelper:
         data.update(extra_data)
 
         request, channel = make_request(
-            self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
+            self.hs.get_reactor(),
+            self.site,
+            "PUT",
+            path,
+            json.dumps(data).encode("utf8"),
         )
 
         render(request, self.site.resource, self.hs.get_reactor())
@@ -159,7 +167,11 @@ class RestHelper:
             path = path + "?access_token=%s" % tok
 
         request, channel = make_request(
-            self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8")
+            self.hs.get_reactor(),
+            self.site,
+            "PUT",
+            path,
+            json.dumps(content).encode("utf8"),
         )
         render(request, self.site.resource, self.hs.get_reactor())
 
@@ -211,7 +223,9 @@ class RestHelper:
         if body is not None:
             content = json.dumps(body).encode("utf8")
 
-        request, channel = make_request(self.hs.get_reactor(), method, path, content)
+        request, channel = make_request(
+            self.hs.get_reactor(), self.site, method, path, content
+        )
 
         render(request, self.site.resource, self.hs.get_reactor())
 
@@ -297,7 +311,12 @@ class RestHelper:
         image_length = len(image_data)
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
         request, channel = make_request(
-            self.hs.get_reactor(), "POST", path, content=image_data, access_token=tok
+            self.hs.get_reactor(),
+            FakeSite(resource),
+            "POST",
+            path,
+            content=image_data,
+            access_token=tok,
         )
         request.requestHeaders.addRawHeader(
             b"Content-Length", str(image_length).encode("UTF-8")
diff --git a/tests/server.py b/tests/server.py
index 3dd2cfc072..b9ccde4962 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -21,6 +21,7 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http import unquote
 from twisted.web.http_headers import Headers
+from twisted.web.resource import IResource
 from twisted.web.server import Site
 
 from synapse.http.site import SynapseRequest
@@ -128,9 +129,21 @@ class FakeSite:
     site_tag = "test"
     access_logger = logging.getLogger("synapse.access.http.fake")
 
+    def __init__(self, resource: IResource):
+        """
+
+        Args:
+            resource: the resource to be used for rendering all requests
+        """
+        self._resource = resource
+
+    def getResourceFor(self, request):
+        return self._resource
+
 
 def make_request(
     reactor,
+    site: Site,
     method,
     path,
     content=b"",
@@ -145,6 +158,8 @@ def make_request(
     content, and return the Request and the Channel underneath.
 
     Args:
+        site: The twisted Site to associate with the Channel
+
         method (bytes/unicode): The HTTP request method ("verb").
         path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
         escaped UTF-8 & spaces and such).
@@ -181,7 +196,6 @@ def make_request(
     if isinstance(content, str):
         content = content.encode("utf8")
 
-    site = FakeSite()
     channel = FakeChannel(site, reactor)
 
     req = request(channel)
diff --git a/tests/test_server.py b/tests/test_server.py
index 655c918a15..300d13ac95 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -26,6 +26,7 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import (
+    FakeSite,
     ThreadedMemoryReactorClock,
     make_request,
     render,
@@ -62,7 +63,7 @@ class JsonResourceTests(unittest.TestCase):
         )
 
         request, channel = make_request(
-            self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
         )
         render(request, res, self.reactor)
 
@@ -83,7 +84,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"500")
@@ -108,7 +111,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"500")
@@ -127,7 +132,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"403")
@@ -150,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )
 
-        request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"400")
@@ -173,7 +182,9 @@ class JsonResourceTests(unittest.TestCase):
         )
 
         # The path was registered as GET, but this is a HEAD request.
-        request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
+        request, channel = make_request(
+            self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo"
+        )
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"200")
@@ -196,9 +207,6 @@ class OptionsResourceTests(unittest.TestCase):
 
     def _make_request(self, method, path):
         """Create a request from the method/path and return a channel with the response."""
-        request, channel = make_request(self.reactor, method, path, shorthand=False)
-        request.prepath = []  # This doesn't get set properly by make_request.
-
         # Create a site and query for the resource.
         site = SynapseSite(
             "test",
@@ -207,6 +215,12 @@ class OptionsResourceTests(unittest.TestCase):
             self.resource,
             "1.0",
         )
+
+        request, channel = make_request(
+            self.reactor, site, method, path, shorthand=False
+        )
+        request.prepath = []  # This doesn't get set properly by make_request.
+
         request.site = site
         resource = site.getResourceFor(request)
 
@@ -284,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        request, channel = make_request(self.reactor, b"GET", b"/path")
+        request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"200")
@@ -303,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        request, channel = make_request(self.reactor, b"GET", b"/path")
+        request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"301")
@@ -325,7 +339,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        request, channel = make_request(self.reactor, b"GET", b"/path")
+        request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"304")
@@ -345,7 +359,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback
 
-        request, channel = make_request(self.reactor, b"HEAD", b"/path")
+        request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
         render(request, res, self.reactor)
 
         self.assertEqual(channel.result["code"], b"200")
diff --git a/tests/unittest.py b/tests/unittest.py
index 0a24c2f6b2..8c7979a7c0 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -434,6 +434,7 @@ class HomeserverTestCase(TestCase):
 
         return make_request(
             self.reactor,
+            self.site,
             method,
             path,
             content,