summary refs log tree commit diff
path: root/synapse/handlers/saml_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r--synapse/handlers/saml_handler.py166
1 files changed, 112 insertions, 54 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..e88fd59749 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -34,7 +34,6 @@ from synapse.types import (
     map_username_to_mxid_localpart,
     mxid_localpart_allowed_characters,
 )
-from synapse.util.async_helpers import Linearizer
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
@@ -59,8 +58,6 @@ class SamlHandler(BaseHandler):
         super().__init__(hs)
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
         self._saml_idp_entityid = hs.config.saml2_idp_entityid
-        self._auth_handler = hs.get_auth_handler()
-        self._registration_handler = hs.get_registration_handler()
 
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
         self._grandfathered_mxid_source_attribute = (
@@ -76,30 +73,46 @@ class SamlHandler(BaseHandler):
         )
 
         # identifier for the external_ids table
-        self._auth_provider_id = "saml"
+        self.idp_id = "saml"
+
+        # user-facing name of this auth provider
+        self.idp_name = "SAML"
+
+        # we do not currently support icons/brands for SAML auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
+        self.idp_brand = None
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
-        # a lock on the mappings
-        self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
         self._sso_handler = hs.get_sso_handler()
+        self._sso_handler.register_identity_provider(self)
 
-    def handle_redirect_request(
-        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
-    ) -> bytes:
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
+            request: the incoming HTTP request
             client_redirect_url: the URL that we should redirect the
-                client to when everything is done
+                client to after login (or None for UI Auth).
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
             URL to redirect to
         """
+        if not client_redirect_url:
+            # Some SAML identity providers (e.g. Google) require a
+            # RelayState parameter on requests, so pass in a dummy redirect URL
+            # (which will never get used).
+            client_redirect_url = b"unused"
+
         reqid, info = self._saml_client.prepare_for_authenticate(
             entityid=self._saml_idp_entityid, relay_state=client_redirect_url
         )
@@ -120,7 +133,7 @@ class SamlHandler(BaseHandler):
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
     async def handle_saml_response(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_matrix/saml2/authn_response
+        """Handle an incoming request to /_synapse/client/saml2/authn_response
 
         Args:
             request: the incoming request from the browser. We'll
@@ -167,6 +180,29 @@ class SamlHandler(BaseHandler):
             return
 
         logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+        await self._handle_authn_response(request, saml2_auth, relay_state)
+
+    async def _handle_authn_response(
+        self,
+        request: SynapseRequest,
+        saml2_auth: saml2.response.AuthnResponse,
+        relay_state: str,
+    ) -> None:
+        """Handle an AuthnResponse, having parsed it from the request params
+
+        Assumes that the signature on the response object has been checked. Maps
+        the user onto an MXID, registering them if necessary, and returns a response
+        to the browser.
+
+        Args:
+            request: the incoming request from the browser. We'll respond to it with an
+                HTML page or a redirect
+            saml2_auth: the parsed AuthnResponse object
+            relay_state: the RelayState query param, which encodes the URI to rediret
+               back to
+        """
+
         for assertion in saml2_auth.assertions:
             # kibana limits the length of a log field, whereas this is all rather
             # useful, so split it up.
@@ -183,6 +219,24 @@ class SamlHandler(BaseHandler):
             saml2_auth.in_response_to, None
         )
 
+        # first check if we're doing a UIA
+        if current_session and current_session.ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+            except MappingException as e:
+                logger.exception("Failed to extract remote user id from SAML response")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self.idp_id,
+                remote_user_id,
+                current_session.ui_auth_session_id,
+                request,
+            )
+
+        # otherwise, we're handling a login request.
+
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
         for requirement in self._saml2_attribute_requirements:
@@ -192,63 +246,39 @@ class SamlHandler(BaseHandler):
                 )
                 return
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
-
         # Call the mapper to register/login the user
         try:
-            user_id = await self._map_saml_response_to_user(
-                saml2_auth, relay_state, user_agent, ip_address
-            )
+            await self._complete_saml_login(saml2_auth, request, relay_state)
         except MappingException as e:
             logger.exception("Could not map user")
             self._sso_handler.render_error(request, "mapping_error", str(e))
-            return
-
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
 
-    async def _map_saml_response_to_user(
+    async def _complete_saml_login(
         self,
         saml2_auth: saml2.response.AuthnResponse,
+        request: SynapseRequest,
         client_redirect_url: str,
-        user_agent: str,
-        ip_address: str,
-    ) -> str:
+    ) -> None:
         """
-        Given a SAML response, retrieve the user ID for it and possibly register the user.
+        Given a SAML response, complete the login flow
+
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
 
         Args:
             saml2_auth: The parsed SAML2 response.
+            request: The request to respond to
             client_redirect_url: The redirect URL passed in by the client.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
-
-        Returns:
-             The user ID associated with this response.
 
         Raises:
             MappingException if there was a problem mapping the response to a user.
             RedirectException: some mapping providers may raise this if they need
                 to redirect to an interstitial page.
         """
-
-        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+        remote_user_id = self._remote_id_from_saml_response(
             saml2_auth, client_redirect_url
         )
 
-        if not remote_user_id:
-            raise MappingException(
-                "Failed to extract remote user id from SAML response"
-            )
-
         async def saml_response_to_remapped_user_attributes(
             failures: int,
         ) -> UserAttributes:
@@ -294,16 +324,44 @@ class SamlHandler(BaseHandler):
 
             return None
 
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
-            return await self._sso_handler.get_mxid_from_sso(
-                self._auth_provider_id,
-                remote_user_id,
-                user_agent,
-                ip_address,
-                saml_response_to_remapped_user_attributes,
-                grandfather_existing_users,
+        await self._sso_handler.complete_sso_login_request(
+            self.idp_id,
+            remote_user_id,
+            request,
+            client_redirect_url,
+            saml_response_to_remapped_user_attributes,
+            grandfather_existing_users,
+        )
+
+    def _remote_id_from_saml_response(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: Optional[str],
+    ) -> str:
+        """Extract the unique remote id from a SAML2 AuthnResponse
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+        Returns:
+            remote user id
+
+        Raises:
+            MappingException if there was an error extracting the user id
+        """
+        # It's not obvious why we need to pass in the redirect URI to the mapping
+        # provider, but we do :/
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+            saml2_auth, client_redirect_url
+        )
+
+        if not remote_user_id:
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
             )
 
+        return remote_user_id
+
     def expire_sessions(self):
         expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()