summary refs log tree commit diff
path: root/synapse/http/server.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-06-03 10:41:12 +0100
committerGitHub <noreply@github.com>2020-06-03 10:41:12 +0100
commit1bbc9e2df6cf9251460ca110918d876d3f50a379 (patch)
treea265286c7b9acf3c97147e1db33c35444d083431 /synapse/http/server.py
parentupdate grafana dashboard (diff)
downloadsynapse-1bbc9e2df6cf9251460ca110918d876d3f50a379.tar.xz
Clean up exception handling in SAML2ResponseResource (#7614)
* Expose `return_html_error`, and allow it to take a Jinja2 template instead of a raw string

* Clean up exception handling in SAML2ResponseResource

  * use the existing code in `return_html_error` instead of re-implementing it
    (giving it a jinja2 template rather than inventing a new form of template)

  * do the exception-catching in the REST layer rather than in the handler
    layer, to make sure we catch all exceptions.

Diffstat (limited to 'synapse/http/server.py')
-rw-r--r--synapse/http/server.py43
1 files changed, 31 insertions, 12 deletions
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 9cc2e2e154..2487a72171 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,13 +21,15 @@ import logging
 import types
 import urllib
 from io import BytesIO
+from typing import Awaitable, Callable, TypeVar, Union
 
+import jinja2
 from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
 
 from twisted.internet import defer
 from twisted.python import failure
 from twisted.web import resource
-from twisted.web.server import NOT_DONE_YET
+from twisted.web.server import NOT_DONE_YET, Request
 from twisted.web.static import NoRangeStaticProducer
 from twisted.web.util import redirectTo
 
@@ -40,6 +42,7 @@ from synapse.api.errors import (
     SynapseError,
     UnrecognizedRequestError,
 )
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import preserve_fn
 from synapse.logging.opentracing import trace_servlet
 from synapse.util.caches import intern_dict
@@ -130,7 +133,12 @@ def wrap_json_request_handler(h):
     return wrap_async_request_handler(wrapped_request_handler)
 
 
-def wrap_html_request_handler(h):
+TV = TypeVar("TV")
+
+
+def wrap_html_request_handler(
+    h: Callable[[TV, SynapseRequest], Awaitable]
+) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
     """Wraps a request handler method with exception handling.
 
     Also does the wrapping with request.processing as per wrap_async_request_handler.
@@ -141,20 +149,26 @@ def wrap_html_request_handler(h):
 
     async def wrapped_request_handler(self, request):
         try:
-            return await h(self, request)
+            await h(self, request)
         except Exception:
             f = failure.Failure()
-            return _return_html_error(f, request)
+            return_html_error(f, request, HTML_ERROR_TEMPLATE)
 
     return wrap_async_request_handler(wrapped_request_handler)
 
 
-def _return_html_error(f, request):
-    """Sends an HTML error page corresponding to the given failure
+def return_html_error(
+    f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
+) -> None:
+    """Sends an HTML error page corresponding to the given failure.
+
+    Handles RedirectException and other CodeMessageExceptions (such as SynapseError)
 
     Args:
-        f (twisted.python.failure.Failure):
-        request (twisted.web.server.Request):
+        f: the error to report
+        request: the failing request
+        error_template: the HTML template. Can be either a string (with `{code}`,
+            `{msg}` placeholders), or a jinja2 template
     """
     if f.check(CodeMessageException):
         cme = f.value
@@ -174,7 +188,7 @@ def _return_html_error(f, request):
                 exc_info=(f.type, f.value, f.getTracebackObject()),
             )
     else:
-        code = http.client.INTERNAL_SERVER_ERROR
+        code = http.HTTPStatus.INTERNAL_SERVER_ERROR
         msg = "Internal server error"
 
         logger.error(
@@ -183,11 +197,16 @@ def _return_html_error(f, request):
             exc_info=(f.type, f.value, f.getTracebackObject()),
         )
 
-    body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
+    if isinstance(error_template, str):
+        body = error_template.format(code=code, msg=html.escape(msg))
+    else:
+        body = error_template.render(code=code, msg=msg)
+
+    body_bytes = body.encode("utf-8")
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-    request.setHeader(b"Content-Length", b"%i" % (len(body),))
-    request.write(body)
+    request.setHeader(b"Content-Length", b"%i" % (len(body_bytes),))
+    request.write(body_bytes)
     finish_request(request)