summary refs log tree commit diff
path: root/synapse/handlers/saml_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r--synapse/handlers/saml_handler.py169
1 files changed, 110 insertions, 59 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 66b063f991..8715abd4d1 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -21,9 +21,10 @@ import saml2
 import saml2.response
 from saml2.client import Saml2Client
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
+from synapse.http.server import respond_with_html
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -41,6 +42,10 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class MappingException(Exception):
+    """Used to catch errors when mapping the SAML2 response to a user."""
+
+
 @attr.s
 class Saml2SessionData:
     """Data we track about SAML2 sessions"""
@@ -68,6 +73,7 @@ class SamlHandler:
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
         self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
+        self._error_template = hs.config.sso_error_template
 
         # plugin to do custom mapping from saml response to mxid
         self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@@ -84,6 +90,25 @@ class SamlHandler:
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
 
+    def _render_error(
+        self, request, error: str, error_description: Optional[str] = None
+    ) -> None:
+        """Render the error template and respond to the request with it.
+
+        This is used to show errors to the user. The template of this page can
+        be found under `synapse/res/templates/sso_error.html`.
+
+        Args:
+            request: The incoming request from the browser.
+                We'll respond with an HTML page describing the error.
+            error: A technical identifier for this error.
+            error_description: A human-readable description of the error.
+        """
+        html = self._error_template.render(
+            error=error, error_description=error_description
+        )
+        respond_with_html(request, 400, html)
+
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
     ) -> bytes:
@@ -134,49 +159,6 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
-            0
-        ].decode("ascii", "surrogateescape")
-        ip_address = self.hs.get_ip_from_request(request)
-
-        user_id, current_session = await self._map_saml_response_to_user(
-            resp_bytes, relay_state, user_agent, ip_address
-        )
-
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
-
-    async def _map_saml_response_to_user(
-        self,
-        resp_bytes: str,
-        client_redirect_url: str,
-        user_agent: str,
-        ip_address: str,
-    ) -> Tuple[str, Optional[Saml2SessionData]]:
-        """
-        Given a sample response, retrieve the cached session and user for it.
-
-        Args:
-            resp_bytes: The SAML response.
-            client_redirect_url: The redirect URL passed in by the client.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Returns:
-             Tuple of the user ID and SAML session associated with this response.
-
-        Raises:
-            SynapseError if there was a problem with the response.
-            RedirectException: some mapping providers may raise this if they need
-                to redirect to an interstitial page.
-        """
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -189,12 +171,23 @@ class SamlHandler:
             # in the (user-visible) exception message, so let's log the exception here
             # so we can track down the session IDs later.
             logger.warning(str(e))
-            raise SynapseError(400, "Unexpected SAML2 login.")
+            self._render_error(
+                request, "unsolicited_response", "Unexpected SAML2 login."
+            )
+            return
         except Exception as e:
-            raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
+            self._render_error(
+                request,
+                "invalid_response",
+                "Unable to parse SAML2 response: %s." % (e,),
+            )
+            return
 
         if saml2_auth.not_signed:
-            raise SynapseError(400, "SAML2 response was not signed.")
+            self._render_error(
+                request, "unsigned_respond", "SAML2 response was not signed."
+            )
+            return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
         for assertion in saml2_auth.assertions:
@@ -213,15 +206,73 @@ class SamlHandler:
             saml2_auth.in_response_to, None
         )
 
+        # Ensure that the attributes of the logged in user meet the required
+        # attributes.
         for requirement in self._saml2_attribute_requirements:
-            _check_attribute_requirement(saml2_auth.ava, requirement)
+            if not _check_attribute_requirement(saml2_auth.ava, requirement):
+                self._render_error(
+                    request, "unauthorised", "You are not authorised to log in here."
+                )
+                return
+
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+            0
+        ].decode("ascii", "surrogateescape")
+        ip_address = self.hs.get_ip_from_request(request)
+
+        # Call the mapper to register/login the user
+        try:
+            user_id = await self._map_saml_response_to_user(
+                saml2_auth, relay_state, user_agent, ip_address
+            )
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._render_error(request, "mapping_error", str(e))
+            return
+
+        # Complete the interactive auth session or the login.
+        if current_session and current_session.ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, current_session.ui_auth_session_id, request
+            )
+
+        else:
+            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: str,
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """
+        Given a SAML response, retrieve the user ID for it and possibly register the user.
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+
+        Returns:
+             The user ID associated with this response.
+
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
+        """
 
         remote_user_id = self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
         )
 
         if not remote_user_id:
-            raise Exception("Failed to extract remote user id from SAML response")
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
 
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
@@ -235,7 +286,7 @@ class SamlHandler:
             )
             if registered_user_id is not None:
                 logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id, current_session
+                return registered_user_id
 
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
@@ -260,7 +311,7 @@ class SamlHandler:
                     await self._datastore.record_user_external_id(
                         self._auth_provider_id, remote_user_id, registered_user_id
                     )
-                    return registered_user_id, current_session
+                    return registered_user_id
 
             # Map saml response to user attributes using the configured mapping provider
             for i in range(1000):
@@ -277,7 +328,7 @@ class SamlHandler:
 
                 localpart = attribute_dict.get("mxid_localpart")
                 if not localpart:
-                    raise Exception(
+                    raise MappingException(
                         "Error parsing SAML2 response: SAML mapping provider plugin "
                         "did not return a mxid_localpart value"
                     )
@@ -294,8 +345,8 @@ class SamlHandler:
             else:
                 # Unable to generate a username in 1000 iterations
                 # Break and return error to the user
-                raise SynapseError(
-                    500, "Unable to generate a Matrix ID from the SAML response"
+                raise MappingException(
+                    "Unable to generate a Matrix ID from the SAML response"
                 )
 
             logger.info("Mapped SAML user to local part %s", localpart)
@@ -310,7 +361,7 @@ class SamlHandler:
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
-            return registered_user_id, current_session
+            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime
@@ -323,11 +374,11 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
-def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
+def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
     values = ava.get(req.attribute, [])
     for v in values:
         if v == req.value:
-            return
+            return True
 
     logger.info(
         "SAML2 attribute %s did not match required value '%s' (was '%s')",
@@ -335,7 +386,7 @@ def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
         req.value,
         values,
     )
-    raise AuthError(403, "You are not authorized to log in here.")
+    return False
 
 
 DOT_REPLACE_PATTERN = re.compile(
@@ -390,7 +441,7 @@ class DefaultSamlMappingProvider:
             return saml_response.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "'uid' not in SAML2 response")
+            raise MappingException("'uid' not in SAML2 response")
 
     def saml_response_to_user_attributes(
         self,