summary refs log tree commit diff
path: root/synapse/handlers/saml_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r--synapse/handlers/saml_handler.py55
1 files changed, 45 insertions, 10 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 72c109981b..96f2dd36ad 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import Tuple
+from typing import Optional, Tuple
 
 import attr
 import saml2
@@ -26,6 +26,7 @@ from synapse.config import ConfigError
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_string
 from synapse.module_api import ModuleApi
+from synapse.module_api.errors import RedirectException
 from synapse.types import (
     UserID,
     map_username_to_mxid_localpart,
@@ -43,11 +44,15 @@ class Saml2SessionData:
 
     # time the session was created, in milliseconds
     creation_time = attr.ib()
+    # The user interactive authentication session ID associated with this SAML
+    # session (or None if this SAML session is for an initial login).
+    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
 
 
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+        self._auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
@@ -76,12 +81,14 @@ class SamlHandler:
 
         self._error_html_content = hs.config.saml2_error_html_content
 
-    def handle_redirect_request(self, client_redirect_url):
+    def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
         """Handle an incoming request to /login/sso/redirect
 
         Args:
             client_redirect_url (bytes): the URL that we should redirect the
                 client to when everything is done
+            ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+                None if this is a login).
 
         Returns:
             bytes: URL to redirect to
@@ -91,7 +98,9 @@ class SamlHandler:
         )
 
         now = self._clock.time_msec()
-        self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now)
+        self._outstanding_requests_dict[reqid] = Saml2SessionData(
+            creation_time=now, ui_auth_session_id=ui_auth_session_id,
+        )
 
         for key, value in info["headers"]:
             if key == "Location":
@@ -118,7 +127,12 @@ class SamlHandler:
         self.expire_sessions()
 
         try:
-            user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+            user_id, current_session = await self._map_saml_response_to_user(
+                resp_bytes, relay_state
+            )
+        except RedirectException:
+            # Raise the exception as per the wishes of the SAML module response
+            raise
         except Exception as e:
             # If decoding the response or mapping it to a user failed, then log the
             # error and tell the user that something went wrong.
@@ -133,9 +147,28 @@ class SamlHandler:
             finish_request(request)
             return
 
-        self._auth_handler.complete_sso_login(user_id, request, relay_state)
+        # Complete the interactive auth session or the login.
+        if current_session and current_session.ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, current_session.ui_auth_session_id, request
+            )
+
+        else:
+            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(
+        self, resp_bytes: str, client_redirect_url: str
+    ) -> Tuple[str, Optional[Saml2SessionData]]:
+        """
+        Given a sample response, retrieve the cached session and user for it.
 
-    async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
+        Args:
+            resp_bytes: The SAML response.
+            client_redirect_url: The redirect URL passed in by the client.
+
+        Returns:
+             Tuple of the user ID and SAML session associated with this response.
+        """
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -163,7 +196,9 @@ class SamlHandler:
 
         logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
 
-        self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
+        current_session = self._outstanding_requests_dict.pop(
+            saml2_auth.in_response_to, None
+        )
 
         remote_user_id = self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
@@ -184,7 +219,7 @@ class SamlHandler:
             )
             if registered_user_id is not None:
                 logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id
+                return registered_user_id, current_session
 
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
@@ -209,7 +244,7 @@ class SamlHandler:
                     await self._datastore.record_user_external_id(
                         self._auth_provider_id, remote_user_id, registered_user_id
                     )
-                    return registered_user_id
+                    return registered_user_id, current_session
 
             # Map saml response to user attributes using the configured mapping provider
             for i in range(1000):
@@ -256,7 +291,7 @@ class SamlHandler:
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
-            return registered_user_id
+            return registered_user_id, current_session
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime