summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-06-03 10:41:12 +0100
committerGitHub <noreply@github.com>2020-06-03 10:41:12 +0100
commit1bbc9e2df6cf9251460ca110918d876d3f50a379 (patch)
treea265286c7b9acf3c97147e1db33c35444d083431 /synapse
parentupdate grafana dashboard (diff)
downloadsynapse-1bbc9e2df6cf9251460ca110918d876d3f50a379.tar.xz
Clean up exception handling in SAML2ResponseResource (#7614)
* Expose `return_html_error`, and allow it to take a Jinja2 template instead of a raw string

* Clean up exception handling in SAML2ResponseResource

  * use the existing code in `return_html_error` instead of re-implementing it
    (giving it a jinja2 template rather than inventing a new form of template)

  * do the exception-catching in the REST layer rather than in the handler
    layer, to make sure we catch all exceptions.

Diffstat (limited to '')
-rw-r--r--synapse/config/saml2_config.py18
-rw-r--r--synapse/handlers/saml_handler.py41
-rw-r--r--synapse/http/server.py43
-rw-r--r--synapse/rest/saml2/response_resource.py26
4 files changed, 68 insertions, 60 deletions
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 726a27d7b2..38ec256984 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -15,8 +15,8 @@
 # limitations under the License.
 
 import logging
-import os
 
+import jinja2
 import pkg_resources
 
 from synapse.python_dependencies import DependencyException, check_requirements
@@ -167,9 +167,11 @@ class SAML2Config(Config):
         if not template_dir:
             template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
 
-        self.saml2_error_html_content = self.read_file(
-            os.path.join(template_dir, "saml_error.html"), "saml2_config.saml_error",
-        )
+        loader = jinja2.FileSystemLoader(template_dir)
+        # enable auto-escape here, to having to remember to escape manually in the
+        # template
+        env = jinja2.Environment(loader=loader, autoescape=True)
+        self.saml2_error_html_template = env.get_template("saml_error.html")
 
     def _default_saml_config_dict(
         self, required_attributes: set, optional_attributes: set
@@ -349,7 +351,13 @@ class SAML2Config(Config):
           # * HTML page to display to users if something goes wrong during the
           #   authentication process: 'saml_error.html'.
           #
-          #   This template doesn't currently need any variable to render.
+          #   When rendering, this template is given the following variables:
+          #     * code: an HTML error code corresponding to the error that is being
+          #       returned (typically 400 or 500)
+          #
+          #     * msg: a textual message describing the error.
+          #
+          #   The variables will automatically be HTML-escaped.
           #
           # You can see the default templates at:
           # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index de6ba4ab55..abecaa8313 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,11 +23,9 @@ from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
-from synapse.http.server import finish_request
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
-from synapse.module_api.errors import RedirectException
 from synapse.types import (
     UserID,
     map_username_to_mxid_localpart,
@@ -80,8 +78,6 @@ class SamlHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
 
-        self._error_html_content = hs.config.saml2_error_html_content
-
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
     ) -> bytes:
@@ -129,26 +125,9 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
-        try:
-            user_id, current_session = await self._map_saml_response_to_user(
-                resp_bytes, relay_state
-            )
-        except RedirectException:
-            # Raise the exception as per the wishes of the SAML module response
-            raise
-        except Exception as e:
-            # If decoding the response or mapping it to a user failed, then log the
-            # error and tell the user that something went wrong.
-            logger.error(e)
-
-            request.setResponseCode(400)
-            request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-            request.setHeader(
-                b"Content-Length", b"%d" % (len(self._error_html_content),)
-            )
-            request.write(self._error_html_content.encode("utf8"))
-            finish_request(request)
-            return
+        user_id, current_session = await self._map_saml_response_to_user(
+            resp_bytes, relay_state
+        )
 
         # Complete the interactive auth session or the login.
         if current_session and current_session.ui_auth_session_id:
@@ -171,6 +150,11 @@ class SamlHandler:
 
         Returns:
              Tuple of the user ID and SAML session associated with this response.
+
+        Raises:
+            SynapseError if there was a problem with the response.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
         """
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
@@ -179,11 +163,9 @@ class SamlHandler:
                 outstanding=self._outstanding_requests_dict,
             )
         except Exception as e:
-            logger.warning("Exception parsing SAML2 response: %s", e)
             raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,))
 
         if saml2_auth.not_signed:
-            logger.warning("SAML2 response was not signed")
             raise SynapseError(400, "SAML2 response was not signed")
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
@@ -264,11 +246,10 @@ class SamlHandler:
 
                 localpart = attribute_dict.get("mxid_localpart")
                 if not localpart:
-                    logger.error(
-                        "SAML mapping provider plugin did not return a "
-                        "mxid_localpart object"
+                    raise Exception(
+                        "Error parsing SAML2 response: SAML mapping provider plugin "
+                        "did not return a mxid_localpart value"
                     )
-                    raise SynapseError(500, "Error parsing SAML2 response")
 
                 displayname = attribute_dict.get("displayname")
                 emails = attribute_dict.get("emails", [])
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 9cc2e2e154..2487a72171 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -21,13 +21,15 @@ import logging
 import types
 import urllib
 from io import BytesIO
+from typing import Awaitable, Callable, TypeVar, Union
 
+import jinja2
 from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
 
 from twisted.internet import defer
 from twisted.python import failure
 from twisted.web import resource
-from twisted.web.server import NOT_DONE_YET
+from twisted.web.server import NOT_DONE_YET, Request
 from twisted.web.static import NoRangeStaticProducer
 from twisted.web.util import redirectTo
 
@@ -40,6 +42,7 @@ from synapse.api.errors import (
     SynapseError,
     UnrecognizedRequestError,
 )
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import preserve_fn
 from synapse.logging.opentracing import trace_servlet
 from synapse.util.caches import intern_dict
@@ -130,7 +133,12 @@ def wrap_json_request_handler(h):
     return wrap_async_request_handler(wrapped_request_handler)
 
 
-def wrap_html_request_handler(h):
+TV = TypeVar("TV")
+
+
+def wrap_html_request_handler(
+    h: Callable[[TV, SynapseRequest], Awaitable]
+) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
     """Wraps a request handler method with exception handling.
 
     Also does the wrapping with request.processing as per wrap_async_request_handler.
@@ -141,20 +149,26 @@ def wrap_html_request_handler(h):
 
     async def wrapped_request_handler(self, request):
         try:
-            return await h(self, request)
+            await h(self, request)
         except Exception:
             f = failure.Failure()
-            return _return_html_error(f, request)
+            return_html_error(f, request, HTML_ERROR_TEMPLATE)
 
     return wrap_async_request_handler(wrapped_request_handler)
 
 
-def _return_html_error(f, request):
-    """Sends an HTML error page corresponding to the given failure
+def return_html_error(
+    f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
+) -> None:
+    """Sends an HTML error page corresponding to the given failure.
+
+    Handles RedirectException and other CodeMessageExceptions (such as SynapseError)
 
     Args:
-        f (twisted.python.failure.Failure):
-        request (twisted.web.server.Request):
+        f: the error to report
+        request: the failing request
+        error_template: the HTML template. Can be either a string (with `{code}`,
+            `{msg}` placeholders), or a jinja2 template
     """
     if f.check(CodeMessageException):
         cme = f.value
@@ -174,7 +188,7 @@ def _return_html_error(f, request):
                 exc_info=(f.type, f.value, f.getTracebackObject()),
             )
     else:
-        code = http.client.INTERNAL_SERVER_ERROR
+        code = http.HTTPStatus.INTERNAL_SERVER_ERROR
         msg = "Internal server error"
 
         logger.error(
@@ -183,11 +197,16 @@ def _return_html_error(f, request):
             exc_info=(f.type, f.value, f.getTracebackObject()),
         )
 
-    body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
+    if isinstance(error_template, str):
+        body = error_template.format(code=code, msg=html.escape(msg))
+    else:
+        body = error_template.render(code=code, msg=msg)
+
+    body_bytes = body.encode("utf-8")
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-    request.setHeader(b"Content-Length", b"%i" % (len(body),))
-    request.write(body)
+    request.setHeader(b"Content-Length", b"%i" % (len(body_bytes),))
+    request.write(body_bytes)
     finish_request(request)
 
 
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index a545c13db7..75e58043b4 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -13,12 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from twisted.python import failure
 
-from synapse.http.server import (
-    DirectServeResource,
-    finish_request,
-    wrap_html_request_handler,
-)
+from synapse.api.errors import SynapseError
+from synapse.http.server import DirectServeResource, return_html_error
 
 
 class SAML2ResponseResource(DirectServeResource):
@@ -28,20 +26,22 @@ class SAML2ResponseResource(DirectServeResource):
 
     def __init__(self, hs):
         super().__init__()
-        self._error_html_content = hs.config.saml2_error_html_content
         self._saml_handler = hs.get_saml_handler()
+        self._error_html_template = hs.config.saml2.saml2_error_html_template
 
     async def _async_render_GET(self, request):
         # We're not expecting any GET request on that resource if everything goes right,
         # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
         # In this case, just tell the user that something went wrong and they should
         # try to authenticate again.
-        request.setResponseCode(400)
-        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-        request.setHeader(b"Content-Length", b"%d" % (len(self._error_html_content),))
-        request.write(self._error_html_content.encode("utf8"))
-        finish_request(request)
+        f = failure.Failure(
+            SynapseError(400, "Unexpected GET request on /saml2/authn_response")
+        )
+        return_html_error(f, request, self._error_html_template)
 
-    @wrap_html_request_handler
     async def _async_render_POST(self, request):
-        return await self._saml_handler.handle_saml_response(request)
+        try:
+            await self._saml_handler.handle_saml_response(request)
+        except Exception:
+            f = failure.Failure()
+            return_html_error(f, request, self._error_html_template)