diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..e9a43b1e45 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,13 @@ from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+ DirectServeResource,
+ JsonResource,
+ OptionsResource,
+ wrap_html_request_handler,
+)
from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
@@ -164,6 +169,157 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+class OptionsResourceTests(unittest.TestCase):
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ class DummyResource(Resource):
+ isLeaf = True
+
+ def render(self, request):
+ return request.path
+
+ # Setup a resource with some children.
+ self.resource = OptionsResource()
+ self.resource.putChild(b"res", DummyResource())
+
+ def _make_request(self, method, path):
+ """Create a request from the method/path and return a channel with the response."""
+ request, channel = make_request(self.reactor, method, path, shorthand=False)
+ request.prepath = [] # This doesn't get set properly by make_request.
+
+ # Create a site and query for the resource.
+ site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+ request.site = site
+ resource = site.getResourceFor(request)
+
+ # Finally, render the resource and return the channel.
+ render(request, resource, self.reactor)
+ return channel
+
+ def test_unknown_options_request(self):
+ """An OPTIONS requests to an unknown URL still returns 200 OK."""
+ channel = self._make_request(b"OPTIONS", b"/foo/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"{}")
+
+ # Ensure the correct CORS headers have been added
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+ "has CORS Origin header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+ "has CORS Methods header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+ "has CORS Headers header",
+ )
+
+ def test_known_options_request(self):
+ """An OPTIONS requests to an known URL still returns 200 OK."""
+ channel = self._make_request(b"OPTIONS", b"/res/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"{}")
+
+ # Ensure the correct CORS headers have been added
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Origin"),
+ "has CORS Origin header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Methods"),
+ "has CORS Methods header",
+ )
+ self.assertTrue(
+ channel.headers.hasHeader(b"Access-Control-Allow-Headers"),
+ "has CORS Headers header",
+ )
+
+ def test_unknown_request(self):
+ """A non-OPTIONS request to an unknown URL should 404."""
+ channel = self._make_request(b"GET", b"/foo/")
+ self.assertEqual(channel.result["code"], b"404")
+
+ def test_known_request(self):
+ """A non-OPTIONS request to an known URL should query the proper resource."""
+ channel = self._make_request(b"GET", b"/res/")
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertEqual(channel.result["body"], b"/res/")
+
+
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+ class TestResource(DirectServeResource):
+ callback = None
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self.callback(request)
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+
+ def test_good_response(self):
+ def callback(request):
+ request.write(b"response")
+ request.finish()
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ body = channel.result["body"]
+ self.assertEqual(body, b"response")
+
+ def test_redirect_exception(self):
+ """
+ If the callback raises a RedirectException, it is turned into a 30x
+ with the right location.
+ """
+
+ def callback(request, **kwargs):
+ raise RedirectException(b"/look/an/eagle", 301)
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"301")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+ def test_redirect_exception_with_cookie(self):
+ """
+ If the callback raises a RedirectException which sets a cookie, that is
+ returned too
+ """
+
+ def callback(request, **kwargs):
+ e = RedirectException(b"/no/over/there", 304)
+ e.cookies.append(b"session=yespls")
+ raise e
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"GET", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"304")
+ headers = channel.result["headers"]
+ location_headers = [v for k, v in headers if k == b"Location"]
+ self.assertEqual(location_headers, [b"/no/over/there"])
+ cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+ self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
class SiteTestCase(unittest.HomeserverTestCase):
def test_lose_connection(self):
"""
|