diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..5fa7ab3f8b 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -34,7 +34,6 @@ from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
-from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -59,8 +58,6 @@ class SamlHandler(BaseHandler):
super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
@@ -81,9 +78,6 @@ class SamlHandler(BaseHandler):
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
- # a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
@@ -167,6 +161,29 @@ class SamlHandler(BaseHandler):
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+ await self._handle_authn_response(request, saml2_auth, relay_state)
+
+ async def _handle_authn_response(
+ self,
+ request: SynapseRequest,
+ saml2_auth: saml2.response.AuthnResponse,
+ relay_state: str,
+ ) -> None:
+ """Handle an AuthnResponse, having parsed it from the request params
+
+ Assumes that the signature on the response object has been checked. Maps
+ the user onto an MXID, registering them if necessary, and returns a response
+ to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+ saml2_auth: the parsed AuthnResponse object
+ relay_state: the RelayState query param, which encodes the URI to rediret
+ back to
+ """
+
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
@@ -183,6 +200,24 @@ class SamlHandler(BaseHandler):
saml2_auth.in_response_to, None
)
+ # first check if we're doing a UIA
+ if current_session and current_session.ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+ except MappingException as e:
+ logger.exception("Failed to extract remote user id from SAML response")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id,
+ remote_user_id,
+ current_session.ui_auth_session_id,
+ request,
+ )
+
+ # otherwise, we're handling a login request.
+
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
@@ -192,63 +227,39 @@ class SamlHandler(BaseHandler):
)
return
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
# Call the mapper to register/login the user
try:
- user_id = await self._map_saml_response_to_user(
- saml2_auth, relay_state, user_agent, ip_address
- )
+ await self._complete_saml_login(saml2_auth, request, relay_state)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
- return
- # Complete the interactive auth session or the login.
- if current_session and current_session.ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, current_session.ui_auth_session_id, request
- )
-
- else:
- await self._auth_handler.complete_sso_login(user_id, request, relay_state)
-
- async def _map_saml_response_to_user(
+ async def _complete_saml_login(
self,
saml2_auth: saml2.response.AuthnResponse,
+ request: SynapseRequest,
client_redirect_url: str,
- user_agent: str,
- ip_address: str,
- ) -> str:
+ ) -> None:
"""
- Given a SAML response, retrieve the user ID for it and possibly register the user.
+ Given a SAML response, complete the login flow
+
+ Retrieves the remote user ID, registers the user if necessary, and serves
+ a redirect back to the client with a login-token.
Args:
saml2_auth: The parsed SAML2 response.
+ request: The request to respond to
client_redirect_url: The redirect URL passed in by the client.
- user_agent: The user agent of the client making the request.
- ip_address: The IP address of the client making the request.
-
- Returns:
- The user ID associated with this response.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
-
- remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ remote_user_id = self._remote_id_from_saml_response(
saml2_auth, client_redirect_url
)
- if not remote_user_id:
- raise MappingException(
- "Failed to extract remote user id from SAML response"
- )
-
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
@@ -294,16 +305,44 @@ class SamlHandler(BaseHandler):
return None
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- return await self._sso_handler.get_mxid_from_sso(
- self._auth_provider_id,
- remote_user_id,
- user_agent,
- ip_address,
- saml_response_to_remapped_user_attributes,
- grandfather_existing_users,
+ await self._sso_handler.complete_sso_login_request(
+ self._auth_provider_id,
+ remote_user_id,
+ request,
+ client_redirect_url,
+ saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
+ )
+
+ def _remote_id_from_saml_response(
+ self,
+ saml2_auth: saml2.response.AuthnResponse,
+ client_redirect_url: Optional[str],
+ ) -> str:
+ """Extract the unique remote id from a SAML2 AuthnResponse
+
+ Args:
+ saml2_auth: The parsed SAML2 response.
+ client_redirect_url: The redirect URL passed in by the client.
+ Returns:
+ remote user id
+
+ Raises:
+ MappingException if there was an error extracting the user id
+ """
+ # It's not obvious why we need to pass in the redirect URI to the mapping
+ # provider, but we do :/
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, client_redirect_url
+ )
+
+ if not remote_user_id:
+ raise MappingException(
+ "Failed to extract remote user id from SAML response"
)
+ return remote_user_id
+
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
|