diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index dc04b53f43..4741c82f61 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import Tuple
+from typing import Optional, Tuple
import attr
import saml2
@@ -44,11 +44,15 @@ class Saml2SessionData:
# time the session was created, in milliseconds
creation_time = attr.ib()
+ # The user interactive authentication session ID associated with this SAML
+ # session (or None if this SAML session is for an initial login).
+ ui_auth_session_id = attr.ib(type=Optional[str], default=None)
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+ self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -77,12 +81,14 @@ class SamlHandler:
self._error_html_content = hs.config.saml2_error_html_content
- def handle_redirect_request(self, client_redirect_url):
+ def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
"""Handle an incoming request to /login/sso/redirect
Args:
client_redirect_url (bytes): the URL that we should redirect the
client to when everything is done
+ ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+ None if this is a login).
Returns:
bytes: URL to redirect to
@@ -92,7 +98,9 @@ class SamlHandler:
)
now = self._clock.time_msec()
- self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now)
+ self._outstanding_requests_dict[reqid] = Saml2SessionData(
+ creation_time=now, ui_auth_session_id=ui_auth_session_id,
+ )
for key, value in info["headers"]:
if key == "Location":
@@ -119,7 +127,9 @@ class SamlHandler:
self.expire_sessions()
try:
- user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+ user_id, current_session = await self._map_saml_response_to_user(
+ resp_bytes, relay_state
+ )
except RedirectException:
# Raise the exception as per the wishes of the SAML module response
raise
@@ -137,9 +147,28 @@ class SamlHandler:
finish_request(request)
return
- self._auth_handler.complete_sso_login(user_id, request, relay_state)
+ # Complete the interactive auth session or the login.
+ if current_session and current_session.ui_auth_session_id:
+ self._auth_handler.complete_sso_ui_auth(
+ user_id, current_session.ui_auth_session_id, request
+ )
+
+ else:
+ self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+ async def _map_saml_response_to_user(
+ self, resp_bytes: str, client_redirect_url: str
+ ) -> Tuple[str, Optional[Saml2SessionData]]:
+ """
+ Given a sample response, retrieve the cached session and user for it.
- async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
+ Args:
+ resp_bytes: The SAML response.
+ client_redirect_url: The redirect URL passed in by the client.
+
+ Returns:
+ Tuple of the user ID and SAML session associated with this response.
+ """
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@@ -167,7 +196,9 @@ class SamlHandler:
logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
- self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
+ current_session = self._outstanding_requests_dict.pop(
+ saml2_auth.in_response_to, None
+ )
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
@@ -188,7 +219,7 @@ class SamlHandler:
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
+ return registered_user_id, current_session
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
@@ -213,7 +244,7 @@ class SamlHandler:
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
- return registered_user_id
+ return registered_user_id, current_session
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
@@ -260,7 +291,7 @@ class SamlHandler:
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
- return registered_user_id
+ return registered_user_id, current_session
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
|