summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6687.misc1
-rw-r--r--synapse/api/errors.py27
-rw-r--r--synapse/http/server.py13
-rw-r--r--tests/test_server.py79
4 files changed, 113 insertions, 7 deletions
diff --git a/changelog.d/6687.misc b/changelog.d/6687.misc
new file mode 100644
index 0000000000..deb0454602
--- /dev/null
+++ b/changelog.d/6687.misc
@@ -0,0 +1 @@
+Allow REST endpoint implementations to raise a RedirectException, which will redirect the user's browser to a given location.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 9e9844b47c..1c9456e583 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,13 +17,15 @@
 """Contains exceptions and error codes."""
 
 import logging
-from typing import Dict
+from typing import Dict, List
 
 from six import iteritems
 from six.moves import http_client
 
 from canonicaljson import json
 
+from twisted.web import http
+
 logger = logging.getLogger(__name__)
 
 
@@ -80,6 +82,29 @@ class CodeMessageException(RuntimeError):
         self.msg = msg
 
 
+class RedirectException(CodeMessageException):
+    """A pseudo-error indicating that we want to redirect the client to a different
+    location
+
+    Attributes:
+        cookies: a list of set-cookies values to add to the response. For example:
+           b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT"
+    """
+
+    def __init__(self, location: bytes, http_code: int = http.FOUND):
+        """
+
+        Args:
+            location: the URI to redirect to
+            http_code: the HTTP response code
+        """
+        msg = "Redirect to %s" % (location.decode("utf-8"),)
+        super().__init__(code=http_code, msg=msg)
+        self.location = location
+
+        self.cookies = []  # type: List[bytes]
+
+
 class SynapseError(CodeMessageException):
     """A base exception type for matrix errors which have an errcode and error
     message (as well as an HTTP status code).
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 943d12c907..04bc2385a2 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,8 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import cgi
 import collections
+import html
 import http.client
 import logging
 import types
@@ -36,6 +36,7 @@ import synapse.metrics
 from synapse.api.errors import (
     CodeMessageException,
     Codes,
+    RedirectException,
     SynapseError,
     UnrecognizedRequestError,
 )
@@ -153,14 +154,18 @@ def _return_html_error(f, request):
 
     Args:
         f (twisted.python.failure.Failure):
-        request (twisted.web.iweb.IRequest):
+        request (twisted.web.server.Request):
     """
     if f.check(CodeMessageException):
         cme = f.value
         code = cme.code
         msg = cme.msg
 
-        if isinstance(cme, SynapseError):
+        if isinstance(cme, RedirectException):
+            logger.info("%s redirect to %s", request, cme.location)
+            request.setHeader(b"location", cme.location)
+            request.cookies.extend(cme.cookies)
+        elif isinstance(cme, SynapseError):
             logger.info("%s SynapseError: %s - %s", request, code, msg)
         else:
             logger.error(
@@ -178,7 +183,7 @@ def _return_html_error(f, request):
             exc_info=(f.type, f.value, f.getTracebackObject()),
         )
 
-    body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8")
+    body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
     request.setHeader(b"Content-Length", b"%i" % (len(body),))
diff --git a/tests/test_server.py b/tests/test_server.py
index 98fef21d55..0d57eed268 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
 from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET
 
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import JsonResource
+from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.http.server import (
+    DirectServeResource,
+    JsonResource,
+    wrap_html_request_handler,
+)
 from synapse.http.site import SynapseSite, logger
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import Clock
@@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
         self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
 
 
+class WrapHtmlRequestHandlerTests(unittest.TestCase):
+    class TestResource(DirectServeResource):
+        callback = None
+
+        @wrap_html_request_handler
+        async def _async_render_GET(self, request):
+            return await self.callback(request)
+
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+
+    def test_good_response(self):
+        def callback(request):
+            request.write(b"response")
+            request.finish()
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"200")
+        body = channel.result["body"]
+        self.assertEqual(body, b"response")
+
+    def test_redirect_exception(self):
+        """
+        If the callback raises a RedirectException, it is turned into a 30x
+        with the right location.
+        """
+
+        def callback(request, **kwargs):
+            raise RedirectException(b"/look/an/eagle", 301)
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"301")
+        headers = channel.result["headers"]
+        location_headers = [v for k, v in headers if k == b"Location"]
+        self.assertEqual(location_headers, [b"/look/an/eagle"])
+
+    def test_redirect_exception_with_cookie(self):
+        """
+        If the callback raises a RedirectException which sets a cookie, that is
+        returned too
+        """
+
+        def callback(request, **kwargs):
+            e = RedirectException(b"/no/over/there", 304)
+            e.cookies.append(b"session=yespls")
+            raise e
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"GET", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"304")
+        headers = channel.result["headers"]
+        location_headers = [v for k, v in headers if k == b"Location"]
+        self.assertEqual(location_headers, [b"/no/over/there"])
+        cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
+        self.assertEqual(cookies_headers, [b"session=yespls"])
+
+
 class SiteTestCase(unittest.HomeserverTestCase):
     def test_lose_connection(self):
         """