diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 37ab42f050..34db10ffe4 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -25,13 +25,12 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
from synapse.types import (
UserID,
- contains_invalid_mxid_characters,
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
@@ -250,14 +249,26 @@ class SamlHandler(BaseHandler):
"Failed to extract remote user id from SAML response"
)
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
- self._auth_provider_id, remote_user_id,
+ async def saml_response_to_remapped_user_attributes(
+ failures: int,
+ ) -> UserAttributes:
+ """
+ Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ # Call the mapping provider.
+ result = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, failures, client_redirect_url
+ )
+ # Remap some of the results.
+ return UserAttributes(
+ localpart=result.get("mxid_localpart"),
+ display_name=result.get("displayname"),
+ emails=result.get("emails"),
)
- if previously_registered_user_id:
- return previously_registered_user_id
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@@ -284,59 +295,13 @@ class SamlHandler(BaseHandler):
)
return registered_user_id
- # Map saml response to user attributes using the configured mapping provider
- for i in range(1000):
- attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, i, client_redirect_url=client_redirect_url,
- )
-
- logger.debug(
- "Retrieved SAML attributes from user mapping provider: %s "
- "(attempt %d)",
- attribute_dict,
- i,
- )
-
- localpart = attribute_dict.get("mxid_localpart")
- if not localpart:
- raise MappingException(
- "Error parsing SAML2 response: SAML mapping provider plugin "
- "did not return a mxid_localpart value"
- )
-
- displayname = attribute_dict.get("displayname")
- emails = attribute_dict.get("emails", [])
-
- # Check if this mxid already exists
- if not await self.store.get_users_by_id_case_insensitive(
- UserID(localpart, self.server_name).to_string()
- ):
- # This mxid is free
- break
- else:
- # Unable to generate a username in 1000 iterations
- # Break and return error to the user
- raise MappingException(
- "Unable to generate a Matrix ID from the SAML response"
- )
-
- # Since the localpart is provided via a potentially untrusted module,
- # ensure the MXID is valid before registering.
- if contains_invalid_mxid_characters(localpart):
- raise MappingException("localpart is invalid: %s" % (localpart,))
-
- logger.debug("Mapped SAML user to local part %s", localpart)
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=displayname,
- bind_emails=emails,
- user_agent_ips=[(user_agent, ip_address)],
- )
-
- await self.store.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ saml_response_to_remapped_user_attributes,
)
- return registered_user_id
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
@@ -451,11 +416,11 @@ class DefaultSamlMappingProvider:
)
# Use the configured mapper for this mxid_source
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
- # a usable mxid
- localpart = base_mxid_localpart + (str(failures) if failures else "")
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
|