diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c5ea2d43a1..b91414aa35 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -14,17 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
+import logging
from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
- map_username_to_mxid_localpart,
- mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+ "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
self.saml2_enabled = True
- self.saml2_mxid_source_attribute = saml2_config.get(
- "mxid_source_attribute", "uid"
- )
-
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)
- saml2_config_dict = self._default_saml_config_dict()
+ # user_mapping_provider may be None if the key is present but has no value
+ ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+ # Use the default user mapping provider if not set
+ ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+ # Ensure a config is present
+ ump_dict["config"] = ump_dict.get("config") or {}
+
+ if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+ # Load deprecated options for use by the default module
+ old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+ if old_mxid_source_attribute:
+ logger.warning(
+ "The config option saml2_config.mxid_source_attribute is deprecated. "
+ "Please use saml2_config.user_mapping_provider.config"
+ ".mxid_source_attribute instead."
+ )
+ ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+ old_mxid_mapping = saml2_config.get("mxid_mapping")
+ if old_mxid_mapping:
+ logger.warning(
+ "The config option saml2_config.mxid_mapping is deprecated. Please "
+ "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+ )
+ ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+ # Retrieve an instance of the module's class
+ # Pass the config dictionary to the module for processing
+ (
+ self.saml2_user_mapping_provider_class,
+ self.saml2_user_mapping_provider_config,
+ ) = load_module(ump_dict)
+
+ # Ensure loaded user mapping module has defined all necessary methods
+ # Note parse_config() is already checked during the call to load_module
+ required_methods = [
+ "get_saml_attributes",
+ "saml_response_to_user_attributes",
+ ]
+ missing_methods = [
+ method
+ for method in required_methods
+ if not hasattr(self.saml2_user_mapping_provider_class, method)
+ ]
+ if missing_methods:
+ raise ConfigError(
+ "Class specified by saml2_config."
+ "user_mapping_provider.module is missing required "
+ "methods: %s" % (", ".join(missing_methods),)
+ )
+
+ # Get the desired saml auth response attributes from the module
+ saml2_config_dict = self._default_saml_config_dict(
+ *self.saml2_user_mapping_provider_class.get_saml_attributes(
+ self.saml2_user_mapping_provider_config
+ )
+ )
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
)
@@ -103,22 +159,27 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "5m")
)
- mapping = saml2_config.get("mxid_mapping", "hexencode")
- try:
- self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
- except KeyError:
- raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-
- def _default_saml_config_dict(self):
+ def _default_saml_config_dict(
+ self, required_attributes: set, optional_attributes: set
+ ):
+ """Generate a configuration dictionary with required and optional attributes that
+ will be needed to process new user registration
+
+ Args:
+ required_attributes: SAML auth response attributes that are
+ necessary to function
+ optional_attributes: SAML auth response attributes that can be used to add
+ additional information to Synapse user accounts, but are not required
+
+ Returns:
+ dict: A SAML configuration dictionary
+ """
import saml2
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
- required_attributes = {"uid", self.saml2_mxid_source_attribute}
-
- optional_attributes = {"displayName"}
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
#
#config_path: "%(config_dir_path)s/sp_conf.py"
- # the lifetime of a SAML session. This defines how long a user has to
+ # The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
- # The SAML attribute (after mapping via the attribute maps) to use to derive
- # the Matrix ID from. 'uid' by default.
+ # An external module can be provided here as a custom solution to
+ # mapping attributes returned from a saml provider onto a matrix user.
#
- #mxid_source_attribute: displayName
-
- # The mapping system to use for mapping the saml attribute onto a matrix ID.
- # Options include:
- # * 'hexencode' (which maps unpermitted characters to '=xx')
- # * 'dotreplace' (which replaces unpermitted characters with '.').
- # The default is 'hexencode'.
- #
- #mxid_mapping: dotreplace
-
- # In previous versions of synapse, the mapping from SAML attribute to MXID was
- # always calculated dynamically rather than stored in a table. For backwards-
- # compatibility, we will look for user_ids matching such a pattern before
- # creating a new account.
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ #
+ #module: mapping_provider.SamlMappingProvider
+
+ # Custom configuration values for the module. Below options are
+ # intended for the built-in provider, they should be changed if
+ # using a custom module. This section will be passed as a Python
+ # dictionary to the module's `parse_config` method.
+ #
+ config:
+ # The SAML attribute (after mapping via the attribute maps) to use
+ # to derive the Matrix ID from. 'uid' by default.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_source_attribute option. If that is still
+ # defined, its value will be used instead.
+ #
+ #mxid_source_attribute: displayName
+
+ # The mapping system to use for mapping the saml attribute onto a
+ # matrix ID.
+ #
+ # Options include:
+ # * 'hexencode' (which maps unpermitted characters to '=xx')
+ # * 'dotreplace' (which replaces unpermitted characters with
+ # '.').
+ # The default is 'hexencode'.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_mapping option. If that is still defined, its
+ # value will be used instead.
+ #
+ #mxid_mapping: dotreplace
+
+ # In previous versions of synapse, the mapping from SAML attribute to
+ # MXID was always calculated dynamically rather than stored in a
+ # table. For backwards- compatibility, we will look for user_ids
+ # matching such a pattern before creating a new account.
#
# This setting controls the SAML attribute which will be used for this
- # backwards-compatibility lookup. Typically it should be 'uid', but if the
- # attribute maps are changed, it may be necessary to change it.
+ # backwards-compatibility lookup. Typically it should be 'uid', but if
+ # the attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
@@ -241,23 +327,3 @@ class SAML2Config(Config):
""" % {
"config_dir_path": config_dir_path
}
-
-
-DOT_REPLACE_PATTERN = re.compile(
- ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
- username = username.lower()
- username = DOT_REPLACE_PATTERN.sub(".", username)
-
- # regular mxids aren't allowed to start with an underscore either
- username = re.sub("^_", "", username)
- return username
-
-
-MXID_MAPPER_MAP = {
- "hexencode": map_username_to_mxid_localpart,
- "dotreplace": dot_replace_for_mxid,
-}
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
+from typing import Tuple
import attr
import saml2
+import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ UserID,
+ map_username_to_mxid_localpart,
+ mxid_localpart_allowed_characters,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
+@attr.s
+class Saml2SessionData:
+ """Data we track about SAML2 sessions"""
+
+ # time the session was created, in milliseconds
+ creation_time = attr.ib()
+
+
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
- self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
- self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+ # plugin to do custom mapping from saml response to mxid
+ self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+ hs.config.saml2_user_mapping_provider_config
+ )
# identifier for the external_ids table
self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
remote_user_id = saml2_auth.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
- raise SynapseError(400, "uid not in SAML2 response")
-
- try:
- mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
- except KeyError:
- logger.warning(
- "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
- )
- raise SynapseError(
- 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
- )
+ raise SynapseError(400, "'uid' not in SAML2 response")
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
- displayName = saml2_auth.ava.get("displayName", [None])[0]
-
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
)
return registered_user_id
- # figure out a new mxid for this user
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ # Map saml response to user attributes using the configured mapping provider
+ for i in range(1000):
+ attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, i
+ )
+
+ logger.debug(
+ "Retrieved SAML attributes from user mapping provider: %s "
+ "(attempt %d)",
+ attribute_dict,
+ i,
+ )
+
+ localpart = attribute_dict.get("mxid_localpart")
+ if not localpart:
+ logger.error(
+ "SAML mapping provider plugin did not return a "
+ "mxid_localpart object"
+ )
+ raise SynapseError(500, "Error parsing SAML2 response")
- suffix = 0
- while True:
- localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+ displayname = attribute_dict.get("displayname")
+
+ # Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
+ # This mxid is free
break
- suffix += 1
- logger.info("Allocating mxid for new user with localpart %s", localpart)
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise SynapseError(
+ 500, "Unable to generate a Matrix ID from the SAML response"
+ )
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=displayName
+ localpart=localpart, default_display_name=displayname
)
+
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
@@ -205,9 +236,120 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
+DOT_REPLACE_PATTERN = re.compile(
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+ username = username.lower()
+ username = DOT_REPLACE_PATTERN.sub(".", username)
+
+ # regular mxids aren't allowed to start with an underscore either
+ username = re.sub("^_", "", username)
+ return username
+
+
+MXID_MAPPER_MAP = {
+ "hexencode": map_username_to_mxid_localpart,
+ "dotreplace": dot_replace_for_mxid,
+}
+
+
@attr.s
-class Saml2SessionData:
- """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+ mxid_source_attribute = attr.ib()
+ mxid_mapper = attr.ib()
- # time the session was created, in milliseconds
- creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+ __version__ = "0.0.1"
+
+ def __init__(self, parsed_config: SamlConfig):
+ """The default SAML user mapping provider
+
+ Args:
+ parsed_config: Module configuration
+ """
+ self._mxid_source_attribute = parsed_config.mxid_source_attribute
+ self._mxid_mapper = parsed_config.mxid_mapper
+
+ def saml_response_to_user_attributes(
+ self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+ ) -> dict:
+ """Maps some text from a SAML response to attributes of a new user
+
+ Args:
+ saml_response: A SAML auth response object
+
+ failures: How many times a call to this function with this
+ saml_response has resulted in a failure
+
+ Returns:
+ dict: A dict containing new user attributes. Possible keys:
+ * mxid_localpart (str): Required. The localpart of the user's mxid
+ * displayname (str): The displayname of the user
+ """
+ try:
+ mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+ except KeyError:
+ logger.warning(
+ "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+ )
+ raise SynapseError(
+ 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ )
+
+ # Use the configured mapper for this mxid_source
+ base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid
+ localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+ # Retrieve the display name from the saml response
+ # If displayname is None, the mxid_localpart will be used instead
+ displayname = saml_response.ava.get("displayName", [None])[0]
+
+ return {
+ "mxid_localpart": localpart,
+ "displayname": displayname,
+ }
+
+ @staticmethod
+ def parse_config(config: dict) -> SamlConfig:
+ """Parse the dict provided by the homeserver's config
+ Args:
+ config: A dictionary containing configuration options for this provider
+ Returns:
+ SamlConfig: A custom config object for this module
+ """
+ # Parse config options and use defaults where necessary
+ mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+ mapping_type = config.get("mxid_mapping", "hexencode")
+
+ # Retrieve the associating mapping function
+ try:
+ mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+ except KeyError:
+ raise ConfigError(
+ "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+ "mxid_mapping value" % (mapping_type,)
+ )
+
+ return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+ @staticmethod
+ def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+ """Returns the required attributes of a SAML
+
+ Args:
+ config: A SamlConfig object containing configuration params for this provider
+
+ Returns:
+ tuple[set,set]: The first set equates to the saml auth response
+ attributes that are required for the module to function, whereas the
+ second set consists of those attributes which can be used if
+ available, but are not necessary
+ """
+ return {"uid", config.mxid_source_attribute}, {"displayName"}
|