summary refs log tree commit diff
path: root/synapse/handlers/saml_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r--synapse/handlers/saml_handler.py91
1 files changed, 28 insertions, 63 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 37ab42f050..34db10ffe4 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -25,13 +25,12 @@ from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
 from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
 from synapse.types import (
     UserID,
-    contains_invalid_mxid_characters,
     map_username_to_mxid_localpart,
     mxid_localpart_allowed_characters,
 )
@@ -250,14 +249,26 @@ class SamlHandler(BaseHandler):
                 "Failed to extract remote user id from SAML response"
             )
 
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
-            # first of all, check if we already have a mapping for this user
-            previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
-                self._auth_provider_id, remote_user_id,
+        async def saml_response_to_remapped_user_attributes(
+            failures: int,
+        ) -> UserAttributes:
+            """
+            Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+            This is backwards compatibility for abstraction for the SSO handler.
+            """
+            # Call the mapping provider.
+            result = self._user_mapping_provider.saml_response_to_user_attributes(
+                saml2_auth, failures, client_redirect_url
+            )
+            # Remap some of the results.
+            return UserAttributes(
+                localpart=result.get("mxid_localpart"),
+                display_name=result.get("displayname"),
+                emails=result.get("emails"),
             )
-            if previously_registered_user_id:
-                return previously_registered_user_id
 
+        with (await self._mapping_lock.queue(self._auth_provider_id)):
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
             if (
@@ -284,59 +295,13 @@ class SamlHandler(BaseHandler):
                     )
                     return registered_user_id
 
-            # Map saml response to user attributes using the configured mapping provider
-            for i in range(1000):
-                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
-                    saml2_auth, i, client_redirect_url=client_redirect_url,
-                )
-
-                logger.debug(
-                    "Retrieved SAML attributes from user mapping provider: %s "
-                    "(attempt %d)",
-                    attribute_dict,
-                    i,
-                )
-
-                localpart = attribute_dict.get("mxid_localpart")
-                if not localpart:
-                    raise MappingException(
-                        "Error parsing SAML2 response: SAML mapping provider plugin "
-                        "did not return a mxid_localpart value"
-                    )
-
-                displayname = attribute_dict.get("displayname")
-                emails = attribute_dict.get("emails", [])
-
-                # Check if this mxid already exists
-                if not await self.store.get_users_by_id_case_insensitive(
-                    UserID(localpart, self.server_name).to_string()
-                ):
-                    # This mxid is free
-                    break
-            else:
-                # Unable to generate a username in 1000 iterations
-                # Break and return error to the user
-                raise MappingException(
-                    "Unable to generate a Matrix ID from the SAML response"
-                )
-
-            # Since the localpart is provided via a potentially untrusted module,
-            # ensure the MXID is valid before registering.
-            if contains_invalid_mxid_characters(localpart):
-                raise MappingException("localpart is invalid: %s" % (localpart,))
-
-            logger.debug("Mapped SAML user to local part %s", localpart)
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=displayname,
-                bind_emails=emails,
-                user_agent_ips=[(user_agent, ip_address)],
-            )
-
-            await self.store.record_user_external_id(
-                self._auth_provider_id, remote_user_id, registered_user_id
+            return await self._sso_handler.get_mxid_from_sso(
+                self._auth_provider_id,
+                remote_user_id,
+                user_agent,
+                ip_address,
+                saml_response_to_remapped_user_attributes,
             )
-            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self.clock.time_msec() - self._saml2_session_lifetime
@@ -451,11 +416,11 @@ class DefaultSamlMappingProvider:
             )
 
         # Use the configured mapper for this mxid_source
-        base_mxid_localpart = self._mxid_mapper(mxid_source)
+        localpart = self._mxid_mapper(mxid_source)
 
         # Append suffix integer if last call to this function failed to produce
-        # a usable mxid
-        localpart = base_mxid_localpart + (str(failures) if failures else "")
+        # a usable mxid.
+        localpart += str(failures) if failures else ""
 
         # Retrieve the display name from the saml response
         # If displayname is None, the mxid_localpart will be used instead