diff options
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r-- | synapse/handlers/saml_handler.py | 198 |
1 files changed, 170 insertions, 28 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index cc9e6b9bd0..0082f85c26 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re +from typing import Tuple import attr import saml2 +import saml2.response from saml2.client import Saml2Client from synapse.api.errors import SynapseError +from synapse.config import ConfigError from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import ( + UserID, + map_username_to_mxid_localpart, + mxid_localpart_allowed_characters, +) from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() + + class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) @@ -37,11 +53,14 @@ class SamlHandler: self._datastore = hs.get_datastore() self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime - self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) - self._mxid_mapper = hs.config.saml2_mxid_mapper + + # plugin to do custom mapping from saml response to mxid + self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( + hs.config.saml2_user_mapping_provider_config + ) # identifier for the external_ids table self._auth_provider_id = "saml" @@ -118,22 +137,10 @@ class SamlHandler: remote_user_id = saml2_auth.ava["uid"][0] except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "uid not in SAML2 response") - - try: - mxid_source = saml2_auth.ava[self._mxid_source_attribute][0] - except KeyError: - logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute - ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) - ) + raise SynapseError(400, "'uid' not in SAML2 response") self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) - displayName = saml2_auth.ava.get("displayName", [None])[0] - with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user logger.info( @@ -173,22 +180,46 @@ class SamlHandler: ) return registered_user_id - # figure out a new mxid for this user - base_mxid_localpart = self._mxid_mapper(mxid_source) + # Map saml response to user attributes using the configured mapping provider + for i in range(1000): + attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, i + ) + + logger.debug( + "Retrieved SAML attributes from user mapping provider: %s " + "(attempt %d)", + attribute_dict, + i, + ) + + localpart = attribute_dict.get("mxid_localpart") + if not localpart: + logger.error( + "SAML mapping provider plugin did not return a " + "mxid_localpart object" + ) + raise SynapseError(500, "Error parsing SAML2 response") - suffix = 0 - while True: - localpart = base_mxid_localpart + (str(suffix) if suffix else "") + displayname = attribute_dict.get("displayname") + + # Check if this mxid already exists if not await self._datastore.get_users_by_id_case_insensitive( UserID(localpart, self._hostname).to_string() ): + # This mxid is free break - suffix += 1 - logger.info("Allocating mxid for new user with localpart %s", localpart) + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise SynapseError( + 500, "Unable to generate a Matrix ID from the SAML response" + ) registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=displayName + localpart=localpart, default_display_name=displayname ) + await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) @@ -205,9 +236,120 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] +DOT_REPLACE_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) +) + + +def dot_replace_for_mxid(username: str) -> str: + username = username.lower() + username = DOT_REPLACE_PATTERN.sub(".", username) + + # regular mxids aren't allowed to start with an underscore either + username = re.sub("^_", "", username) + return username + + +MXID_MAPPER_MAP = { + "hexencode": map_username_to_mxid_localpart, + "dotreplace": dot_replace_for_mxid, +} + + @attr.s -class Saml2SessionData: - """Data we track about SAML2 sessions""" +class SamlConfig(object): + mxid_source_attribute = attr.ib() + mxid_mapper = attr.ib() - # time the session was created, in milliseconds - creation_time = attr.ib() + +class DefaultSamlMappingProvider(object): + __version__ = "0.0.1" + + def __init__(self, parsed_config: SamlConfig): + """The default SAML user mapping provider + + Args: + parsed_config: Module configuration + """ + self._mxid_source_attribute = parsed_config.mxid_source_attribute + self._mxid_mapper = parsed_config.mxid_mapper + + def saml_response_to_user_attributes( + self, saml_response: saml2.response.AuthnResponse, failures: int = 0, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. The localpart of the user's mxid + * displayname (str): The displayname of the user + """ + try: + mxid_source = saml_response.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + + # Use the configured mapper for this mxid_source + base_mxid_localpart = self._mxid_mapper(mxid_source) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid + localpart = base_mxid_localpart + (str(failures) if failures else "") + + # Retrieve the display name from the saml response + # If displayname is None, the mxid_localpart will be used instead + displayname = saml_response.ava.get("displayName", [None])[0] + + return { + "mxid_localpart": localpart, + "displayname": displayname, + } + + @staticmethod + def parse_config(config: dict) -> SamlConfig: + """Parse the dict provided by the homeserver's config + Args: + config: A dictionary containing configuration options for this provider + Returns: + SamlConfig: A custom config object for this module + """ + # Parse config options and use defaults where necessary + mxid_source_attribute = config.get("mxid_source_attribute", "uid") + mapping_type = config.get("mxid_mapping", "hexencode") + + # Retrieve the associating mapping function + try: + mxid_mapper = MXID_MAPPER_MAP[mapping_type] + except KeyError: + raise ConfigError( + "saml2_config.user_mapping_provider.config: '%s' is not a valid " + "mxid_mapping value" % (mapping_type,) + ) + + return SamlConfig(mxid_source_attribute, mxid_mapper) + + @staticmethod + def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: + """Returns the required attributes of a SAML + + Args: + config: A SamlConfig object containing configuration params for this provider + + Returns: + tuple[set,set]: The first set equates to the saml auth response + attributes that are required for the module to function, whereas the + second set consists of those attributes which can be used if + available, but are not necessary + """ + return {"uid", config.mxid_source_attribute}, {"displayName"} |