summary refs log tree commit diff
path: root/tests/test_server.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2020-01-15 16:00:24 +0000
committerRichard van der Hoff <richard@matrix.org>2020-01-15 16:00:24 +0000
commit107f256cd8c8d72c7a36df7e426aded09c5657c9 (patch)
tree788f2b18ebab979aa87462c7b1759da2e376320a /tests/test_server.py
parentchangelog (diff)
parentImplement RedirectException (#6687) (diff)
downloadsynapse-107f256cd8c8d72c7a36df7e426aded09c5657c9.tar.xz
Merge branch 'develop' into rav/module_api_extensions
Diffstat (limited to 'tests/test_server.py')
-rw-r--r--tests/test_server.py79
1 files changed, 77 insertions, 2 deletions
diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..0d57eed268 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
 from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET
 
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+    DirectServeResource,
+    JsonResource,
+    wrap_html_request_handler,
+)
 from synapse.http.site import SynapseSite, logger
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import Clock
@@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
         self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
 
 
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+    class TestResource(DirectServeResource):
+        callback = None
+
+        @wrap_html_request_handler
+        async def _async_render_GET(self, request):
+            return await self.callback(request)
+
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+
+    def test_good_response(self):
+        def callback(request):
+            request.write(b"response")
+            request.finish()
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"200")
+        body = channel.result["body"]
+        self.assertEqual(body, b"response")
+
+    def test_redirect_exception(self):
+        """
+        If the callback raises a RedirectException, it is turned into a 30x
+        with the right location.
+        """
+
+        def callback(request, **kwargs):
+            raise RedirectException(b"/look/an/eagle", 301)
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"301")
+        headers = channel.result["headers"]
+        location_headers = [v for k, v in headers if k == b"Location"]
+        self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+    def test_redirect_exception_with_cookie(self):
+        """
+        If the callback raises a RedirectException which sets a cookie, that is
+        returned too
+        """
+
+        def callback(request, **kwargs):
+            e = RedirectException(b"/no/over/there", 304)
+            e.cookies.append(b"session=yespls")
+            raise e
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"304")
+        headers = channel.result["headers"]
+        location_headers = [v for k, v in headers if k == b"Location"]
+        self.assertEqual(location_headers, [b"/no/over/there"])
+        cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+        self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
 class SiteTestCase(unittest.HomeserverTestCase):
     def test_lose_connection(self):
         """