diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..0d57eed268 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+ DirectServeResource,
+ JsonResource,
+ wrap_html_request_handler,
+)
from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
@@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+ class TestResource(DirectServeResource):
+ callback = None
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self.callback(request)
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def test_good_response(self):
+ def callback(request):
+ request.write(b"response")
+ request.finish()
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ body = channel.result["body"]
+ self.assertEqual(body, b"response")
+
+ def test_redirect_exception(self):
+ """
+ If the callback raises a RedirectException, it is turned into a 30x
+ with the right location.
+ """
+
+ def callback(request, **kwargs):
+ raise RedirectException(b"/look/an/eagle", 301)
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"301")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+ def test_redirect_exception_with_cookie(self):
+ """
+ If the callback raises a RedirectException which sets a cookie, that is
+ returned too
+ """
+
+ def callback(request, **kwargs):
+ e = RedirectException(b"/no/over/there", 304)
+ e.cookies.append(b"session=yespls")
+ raise e
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"304")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/no/over/there"])
+ cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+ self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
class SiteTestCase(unittest.HomeserverTestCase):
def test_lose_connection(self):
"""
|