summary refs log tree commit diff
path: root/synapse/handlers/saml_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r--synapse/handlers/saml_handler.py198
1 files changed, 170 insertions, 28 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
+from typing import Tuple
 
 import attr
 import saml2
+import saml2.response
 from saml2.client import Saml2Client
 
 from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
 from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+    UserID,
+    map_username_to_mxid_localpart,
+    mxid_localpart_allowed_characters,
+)
 from synapse.util.async_helpers import Linearizer
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class Saml2SessionData:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib()
+
+
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
         self._datastore = hs.get_datastore()
         self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
-        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
         )
-        self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+        # plugin to do custom mapping from saml response to mxid
+        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+            hs.config.saml2_user_mapping_provider_config
+        )
 
         # identifier for the external_ids table
         self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
             remote_user_id = saml2_auth.ava["uid"][0]
         except KeyError:
             logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "uid not in SAML2 response")
-
-        try:
-            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
-        except KeyError:
-            logger.warning(
-                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
-            )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
-            )
+            raise SynapseError(400, "'uid' not in SAML2 response")
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
-        displayName = saml2_auth.ava.get("displayName", [None])[0]
-
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
                     )
                     return registered_user_id
 
-            # figure out a new mxid for this user
-            base_mxid_localpart = self._mxid_mapper(mxid_source)
+            # Map saml response to user attributes using the configured mapping provider
+            for i in range(1000):
+                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+                    saml2_auth, i
+                )
+
+                logger.debug(
+                    "Retrieved SAML attributes from user mapping provider: %s "
+                    "(attempt %d)",
+                    attribute_dict,
+                    i,
+                )
+
+                localpart = attribute_dict.get("mxid_localpart")
+                if not localpart:
+                    logger.error(
+                        "SAML mapping provider plugin did not return a "
+                        "mxid_localpart object"
+                    )
+                    raise SynapseError(500, "Error parsing SAML2 response")
 
-            suffix = 0
-            while True:
-                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+                displayname = attribute_dict.get("displayname")
+
+                # Check if this mxid already exists
                 if not await self._datastore.get_users_by_id_case_insensitive(
                     UserID(localpart, self._hostname).to_string()
                 ):
+                    # This mxid is free
                     break
-                suffix += 1
-            logger.info("Allocating mxid for new user with localpart %s", localpart)
+            else:
+                # Unable to generate a username in 1000 iterations
+                # Break and return error to the user
+                raise SynapseError(
+                    500, "Unable to generate a Matrix ID from the SAML response"
+                )
 
             registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=displayName
+                localpart=localpart, default_display_name=displayname
             )
+
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
@@ -205,9 +236,120 @@ class SamlHandler:
             del self._outstanding_requests_dict[reqid]
 
 
+DOT_REPLACE_PATTERN = re.compile(
+    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+    username = username.lower()
+    username = DOT_REPLACE_PATTERN.sub(".", username)
+
+    # regular mxids aren't allowed to start with an underscore either
+    username = re.sub("^_", "", username)
+    return username
+
+
+MXID_MAPPER_MAP = {
+    "hexencode": map_username_to_mxid_localpart,
+    "dotreplace": dot_replace_for_mxid,
+}
+
+
 @attr.s
-class Saml2SessionData:
-    """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+    mxid_source_attribute = attr.ib()
+    mxid_mapper = attr.ib()
 
-    # time the session was created, in milliseconds
-    creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+    __version__ = "0.0.1"
+
+    def __init__(self, parsed_config: SamlConfig):
+        """The default SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+        """
+        self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        self._mxid_mapper = parsed_config.mxid_mapper
+
+    def saml_response_to_user_attributes(
+        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+        """
+        try:
+            mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+        except KeyError:
+            logger.warning(
+                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+            )
+            raise SynapseError(
+                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            )
+
+        # Use the configured mapper for this mxid_source
+        base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid
+        localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+        # Retrieve the display name from the saml response
+        # If displayname is None, the mxid_localpart will be used instead
+        displayname = saml_response.ava.get("displayName", [None])[0]
+
+        return {
+            "mxid_localpart": localpart,
+            "displayname": displayname,
+        }
+
+    @staticmethod
+    def parse_config(config: dict) -> SamlConfig:
+        """Parse the dict provided by the homeserver's config
+        Args:
+            config: A dictionary containing configuration options for this provider
+        Returns:
+            SamlConfig: A custom config object for this module
+        """
+        # Parse config options and use defaults where necessary
+        mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+        mapping_type = config.get("mxid_mapping", "hexencode")
+
+        # Retrieve the associating mapping function
+        try:
+            mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+        except KeyError:
+            raise ConfigError(
+                "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+                "mxid_mapping value" % (mapping_type,)
+            )
+
+        return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+    @staticmethod
+    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+        """Returns the required attributes of a SAML
+
+        Args:
+            config: A SamlConfig object containing configuration params for this provider
+
+        Returns:
+            tuple[set,set]: The first set equates to the saml auth response
+                attributes that are required for the module to function, whereas the
+                second set consists of those attributes which can be used if
+                available, but are not necessary
+        """
+        return {"uid", config.mxid_source_attribute}, {"displayName"}