summary refs log tree commit diff
path: root/tests/test_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_server.py')
-rw-r--r--tests/test_server.py160
1 files changed, 158 insertions, 2 deletions
diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..e9a43b1e45 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,13 @@ from twisted.test.proto_helpers import AccumulatingProtocol
 from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET
 
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+    DirectServeResource,
+    JsonResource,
+    OptionsResource,
+    wrap_html_request_handler,
+)
 from synapse.http.site import SynapseSite, logger
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import Clock
@@ -164,6 +169,157 @@ class JsonResourceTests(unittest.TestCase):
         self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
 
 
+class OptionsResourceTests(unittest.TestCase):
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+
+        class DummyResource(Resource):
+            isLeaf = True
+
+            def render(self, request):
+                return request.path
+
+        # Setup a resource with some children.
+        self.resource = OptionsResource()
+        self.resource.putChild(b"res", DummyResource())
+
+    def _make_request(self, method, path):
+        """Create a request from the method/path and return a channel with the response."""
+        request, channel = make_request(self.reactor, method, path, shorthand=False)
+        request.prepath = []  # This doesn't get set properly by make_request.
+
+        # Create a site and query for the resource.
+        site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+        request.site = site
+        resource = site.getResourceFor(request)
+
+        # Finally, render the resource and return the channel.
+        render(request, resource, self.reactor)
+        return channel
+
+    def test_unknown_options_request(self):
+        """An OPTIONS requests to an unknown URL still returns 200 OK."""
+        channel = self._make_request(b"OPTIONS", b"/foo/")
+        self.assertEqual(channel.result["code"], b"200")
+        self.assertEqual(channel.result["body"], b"{}")
+
+        # Ensure the correct CORS headers have been added
+        self.assertTrue(
+            channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+            "has CORS Origin header",
+        )
+        self.assertTrue(
+            channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+            "has CORS Methods header",
+        )
+        self.assertTrue(
+            channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+            "has CORS Headers header",
+        )
+
+    def test_known_options_request(self):
+        """An OPTIONS requests to an known URL still returns 200 OK."""
+        channel = self._make_request(b"OPTIONS", b"/res/")
+        self.assertEqual(channel.result["code"], b"200")
+        self.assertEqual(channel.result["body"], b"{}")
+
+        # Ensure the correct CORS headers have been added
+        self.assertTrue(
+            channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+            "has CORS Origin header",
+        )
+        self.assertTrue(
+            channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+            "has CORS Methods header",
+        )
+        self.assertTrue(
+            channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+            "has CORS Headers header",
+        )
+
+    def test_unknown_request(self):
+        """A non-OPTIONS request to an unknown URL should 404."""
+        channel = self._make_request(b"GET", b"/foo/")
+        self.assertEqual(channel.result["code"], b"404")
+
+    def test_known_request(self):
+        """A non-OPTIONS request to an known URL should query the proper resource."""
+        channel = self._make_request(b"GET", b"/res/")
+        self.assertEqual(channel.result["code"], b"200")
+        self.assertEqual(channel.result["body"], b"/res/")
+
+
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+    class TestResource(DirectServeResource):
+        callback = None
+
+        @wrap_html_request_handler
+        async def _async_render_GET(self, request):
+            return await self.callback(request)
+
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+
+    def test_good_response(self):
+        def callback(request):
+            request.write(b"response")
+            request.finish()
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"200")
+        body = channel.result["body"]
+        self.assertEqual(body, b"response")
+
+    def test_redirect_exception(self):
+        """
+        If the callback raises a RedirectException, it is turned into a 30x
+        with the right location.
+        """
+
+        def callback(request, **kwargs):
+            raise RedirectException(b"/look/an/eagle", 301)
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"301")
+        headers = channel.result["headers"]
+        location_headers = [v for k, v in headers if k == b"Location"]
+        self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+    def test_redirect_exception_with_cookie(self):
+        """
+        If the callback raises a RedirectException which sets a cookie, that is
+        returned too
+        """
+
+        def callback(request, **kwargs):
+            e = RedirectException(b"/no/over/there", 304)
+            e.cookies.append(b"session=yespls")
+            raise e
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"304")
+        headers = channel.result["headers"]
+        location_headers = [v for k, v in headers if k == b"Location"]
+        self.assertEqual(location_headers, [b"/no/over/there"])
+        cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+        self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
 class SiteTestCase(unittest.HomeserverTestCase):
     def test_lose_connection(self):
         """