diff options
author | Azrenbeth <7782548+Azrenbeth@users.noreply.github.com> | 2021-08-24 11:33:02 +0100 |
---|---|---|
committer | Azrenbeth <7782548+Azrenbeth@users.noreply.github.com> | 2021-08-24 14:38:22 +0100 |
commit | 2b3e4e856fd00aee36423c2b4dcdb5b43d3e213c (patch) | |
tree | 395f100d902c5f3b29a090b06c9ad1e51dfd74a3 | |
parent | Persist room hierarchy pagination sessions to the database. (#10613) (diff) | |
download | synapse-2b3e4e856fd00aee36423c2b4dcdb5b43d3e213c.tar.xz |
Port the saml mapping providers to new module interface
-rw-r--r-- | docs/sample_config.yaml | 18 | ||||
-rw-r--r-- | synapse/app/_base.py | 6 | ||||
-rw-r--r-- | synapse/config/saml2.py | 110 | ||||
-rw-r--r-- | synapse/handlers/saml.py | 358 | ||||
-rw-r--r-- | synapse/module_api/__init__.py | 8 | ||||
-rw-r--r-- | synapse/module_api/errors.py | 1 | ||||
-rw-r--r-- | synapse/rest/synapse/client/saml2/metadata_resource.py | 2 | ||||
-rw-r--r-- | synapse/server.py | 5 | ||||
-rw-r--r-- | synapse/util/__init__.py | 28 |
9 files changed, 389 insertions, 147 deletions
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 935841dbfa..6a3c558b7e 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1544,7 +1544,9 @@ saml2_config: # # Default values will be used for the 'entityid' and 'service' settings, # so it is not normally necessary to specify them unless you need to - # override them. + # override them. Note that setting 'service.sp.required_attributes' or + # 'service.sp.optional_attributes' here will override anything configured + # by a module that registers saml2 user mapping provider callbacks # sp_config: # Point this to the IdP's metadata. You must provide either a local @@ -1622,18 +1624,14 @@ saml2_config: # #saml_session_lifetime: 5m - # An external module can be provided here as a custom solution to - # mapping attributes returned from a saml provider onto a matrix user. + # Setting for the default mapping provider which maps attributes returned + # from a saml provider onto a matrix user. Custom solutions can be used by + # adding a module that provides these features to the 'modules' config + # section, in which case the following section will be ignored. # user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # - #module: mapping_provider.SamlMappingProvider - # Custom configuration values for the module. Below options are - # intended for the built-in provider, they should be changed if - # using a custom module. This section will be passed as a Python - # dictionary to the module's `parse_config` method. + # intended for the built-in provider. # config: # The SAML attribute (after mapping via the attribute maps) to use diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 39e28aff9f..4299327501 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -40,6 +40,7 @@ from synapse.crypto import context_factory from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules +from synapse.handlers.saml import load_default_or_legacy_saml2_mapping_provider from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -372,6 +373,11 @@ async def start(hs: "HomeServer"): load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) load_legacy_presence_router(hs) + # 'module_has_registered' is true if a module calls 'register_saml2_user_mapping_provider_callbacks' + # Only one mapping provider can be set, so only load default (or legacy configured one) if this is + # still false + if not hs.get_saml2_user_mapping_provider().module_has_registered: + load_default_or_legacy_saml2_mapping_provider(hs) # If we've configured an expiry time for caches, start the background job now. setup_expire_lru_cache_entries(hs) diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 05e983625d..b60e945152 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -18,6 +18,7 @@ from typing import Any, List from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements +from synapse.util import dict_merge from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError @@ -33,34 +34,6 @@ LEGACY_USER_MAPPING_PROVIDER = ( ) -def _dict_merge(merge_dict, into_dict): - """Do a deep merge of two dicts - - Recursively merges `merge_dict` into `into_dict`: - * For keys where both `merge_dict` and `into_dict` have a dict value, the values - are recursively merged - * For all other keys, the values in `into_dict` (if any) are overwritten with - the value from `merge_dict`. - - Args: - merge_dict (dict): dict to merge - into_dict (dict): target dict - """ - for k, v in merge_dict.items(): - if k not in into_dict: - into_dict[k] = v - continue - - current_val = into_dict[k] - - if isinstance(v, dict) and isinstance(current_val, dict): - _dict_merge(v, current_val) - continue - - # otherwise we just overwrite - into_dict[k] = v - - class SAML2Config(Config): section = "saml2" @@ -99,11 +72,15 @@ class SAML2Config(Config): ump_dict = saml2_config.get("user_mapping_provider") or {} # Use the default user mapping provider if not set + # NOTE this is the legacy way of using custom modules + # New style-modules should be placed in the 'modules:' config section ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) if ump_dict.get("module") == LEGACY_USER_MAPPING_PROVIDER: ump_dict["module"] = DEFAULT_USER_MAPPING_PROVIDER # Ensure a config is present + # This is the config for the default mapping provider, or the legacy + # way of configuring a custom module ump_dict["config"] = ump_dict.get("config") or {} if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER: @@ -132,59 +109,30 @@ class SAML2Config(Config): self.saml2_user_mapping_provider_config, ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider")) - # Ensure loaded user mapping module has defined all necessary methods - # Note parse_config() is already checked during the call to load_module - required_methods = [ - "get_saml_attributes", - "saml_response_to_user_attributes", - "get_remote_user_id", - ] - missing_methods = [ - method - for method in required_methods - if not hasattr(self.saml2_user_mapping_provider_class, method) - ] - if missing_methods: - raise ConfigError( - "Class specified by saml2_config." - "user_mapping_provider.module is missing required " - "methods: %s" % (", ".join(missing_methods),) - ) - - # Get the desired saml auth response attributes from the module - saml2_config_dict = self._default_saml_config_dict( - *self.saml2_user_mapping_provider_class.get_saml_attributes( - self.saml2_user_mapping_provider_config - ) - ) - _dict_merge( - merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict + # This is only the *base* config since a custom user mapping provider can change + # the values of 'service.sp.required_attributes' and 'service.sp.optional_attributes' + self.base_sp_config = self._default_sp_config_dict() + dict_merge( + merge_dict=saml2_config.get("sp_config", {}), into_dict=self.base_sp_config ) - config_path = saml2_config.get("config_path", None) - if config_path is not None: - mod = load_python_module(config_path) - config = getattr(mod, "CONFIG", None) - if config is None: + sp_config_path = saml2_config.get("config_path", None) + if sp_config_path is not None: + mod = load_python_module(sp_config_path) + sp_config_from_file = getattr(mod, "CONFIG", None) + if sp_config_from_file is None: raise ConfigError( "Config path specified by saml2_config.config_path does not " "have a CONFIG property." ) - _dict_merge(merge_dict=config, into_dict=saml2_config_dict) - - import saml2.config - - self.saml2_sp_config = saml2.config.SPConfig() - self.saml2_sp_config.load(saml2_config_dict) + dict_merge(merge_dict=sp_config_from_file, into_dict=self.base_sp_config) # session lifetime: in milliseconds self.saml2_session_lifetime = self.parse_duration( saml2_config.get("saml_session_lifetime", "15m") ) - def _default_saml_config_dict( - self, required_attributes: set, optional_attributes: set - ): + def _default_sp_config_dict(self): """Generate a configuration dictionary with required and optional attributes that will be needed to process new user registration @@ -203,10 +151,6 @@ class SAML2Config(Config): if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") - if self.saml2_grandfathered_mxid_source_attribute: - optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) - optional_attributes -= required_attributes - metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml" response_url = public_baseurl + "_synapse/client/saml2/authn_response" return { @@ -218,8 +162,6 @@ class SAML2Config(Config): (response_url, saml2.BINDING_HTTP_POST) ] }, - "required_attributes": list(required_attributes), - "optional_attributes": list(optional_attributes), # "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT, } }, @@ -257,7 +199,9 @@ class SAML2Config(Config): # # Default values will be used for the 'entityid' and 'service' settings, # so it is not normally necessary to specify them unless you need to - # override them. + # override them. Note that setting 'service.sp.required_attributes' or + # 'service.sp.optional_attributes' here will override anything configured + # by a module that registers saml2 user mapping provider callbacks # sp_config: # Point this to the IdP's metadata. You must provide either a local @@ -335,18 +279,14 @@ class SAML2Config(Config): # #saml_session_lifetime: 5m - # An external module can be provided here as a custom solution to - # mapping attributes returned from a saml provider onto a matrix user. + # Setting for the default mapping provider which maps attributes returned + # from a saml provider onto a matrix user. Custom solutions can be used by + # adding a module that provides these features to the 'modules' config + # section, in which case the following section will be ignored. # user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # - #module: mapping_provider.SamlMappingProvider - # Custom configuration values for the module. Below options are - # intended for the built-in provider, they should be changed if - # using a custom module. This section will be passed as a Python - # dictionary to the module's `parse_config` method. + # intended for the built-in provider. # config: # The SAML attribute (after mapping via the attribute maps) to use diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index e6e71e9729..6ef2b0de02 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -13,15 +13,16 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Set, Tuple import attr import saml2 import saml2.response from saml2.client import Saml2Client -from synapse.api.errors import SynapseError +from synapse.api.errors import RedirectException from synapse.config import ConfigError +from synapse.config.saml2 import DEFAULT_USER_MAPPING_PROVIDER from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string @@ -32,6 +33,8 @@ from synapse.types import ( map_username_to_mxid_localpart, mxid_localpart_allowed_characters, ) +from synapse.util import dict_merge +from synapse.util.async_helpers import maybe_awaitable from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: @@ -54,7 +57,50 @@ class Saml2SessionData: class SamlHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self._saml_client = Saml2Client(hs.config.saml2_sp_config) + + # If support for legacy saml2_mapping_providers is dropped then this + # is where the DefaultSamlMappingProvider should be loaded + + self._user_mapping_provider = hs.get_saml2_user_mapping_provider() + + # At this point either a module will have registered user mapping provider + # callbacks or the default will have been registered. + assert self._user_mapping_provider.module_has_registered + + # Merge the required and optional saml_attributes registered by the mapping + # provider with the base sp config. NOTE: If there are conflicts then the + # module's expected attributes are overwritten by the base sp_config. This is + # how it worked with legacy modules. + ( + required_attributes, + optional_attributes, + ) = self._user_mapping_provider.get_saml_attributes() + + # Required for backwards compatability + if hs.config.saml2_grandfathered_mxid_source_attribute: + optional_attributes.add(hs.config.saml2_grandfathered_mxid_source_attribute) + + optional_attributes -= required_attributes + + sp_config_dict = { + "service": { + "sp": { + "required_attributes": list(required_attributes), + "optional_attributes": list(optional_attributes), + } + }, + } + + # Merged this way around for backwards compatability + dict_merge( + merge_dict=hs.config.saml2.base_sp_config, + into_dict=sp_config_dict, + ) + + self.saml2_sp_config = saml2.config.SPConfig() + self.saml2_sp_config.load(sp_config_dict) + + self._saml_client = Saml2Client(self.saml2_sp_config) self._saml_idp_entityid = hs.config.saml2_idp_entityid self._saml2_session_lifetime = hs.config.saml2_session_lifetime @@ -64,12 +110,6 @@ class SamlHandler(BaseHandler): self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements self._error_template = hs.config.sso_error_template - # plugin to do custom mapping from saml response to mxid - self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( - hs.config.saml2_user_mapping_provider_config, - ModuleApi(hs, hs.get_auth_handler()), - ) - # identifier for the external_ids table self.idp_id = "saml" @@ -222,7 +262,9 @@ class SamlHandler(BaseHandler): # first check if we're doing a UIA if current_session and current_session.ui_auth_session_id: try: - remote_user_id = self._remote_id_from_saml_response(saml2_auth, None) + remote_user_id = await self._user_mapping_provider.get_remote_user_id( + saml2_auth, None + ) except MappingException as e: logger.exception("Failed to extract remote user id from SAML response") self._sso_handler.render_error(request, "mapping_error", str(e)) @@ -273,7 +315,7 @@ class SamlHandler(BaseHandler): RedirectException: some mapping providers may raise this if they need to redirect to an interstitial page. """ - remote_user_id = self._remote_id_from_saml_response( + remote_user_id = await self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url ) @@ -286,7 +328,7 @@ class SamlHandler(BaseHandler): This is backwards compatibility for abstraction for the SSO handler. """ # Call the mapping provider. - result = self._user_mapping_provider.saml_response_to_user_attributes( + result = await self._user_mapping_provider.saml_response_to_user_attributes( saml2_auth, failures, client_redirect_url ) # Remap some of the results. @@ -331,35 +373,6 @@ class SamlHandler(BaseHandler): grandfather_existing_users, ) - def _remote_id_from_saml_response( - self, - saml2_auth: saml2.response.AuthnResponse, - client_redirect_url: Optional[str], - ) -> str: - """Extract the unique remote id from a SAML2 AuthnResponse - - Args: - saml2_auth: The parsed SAML2 response. - client_redirect_url: The redirect URL passed in by the client. - Returns: - remote user id - - Raises: - MappingException if there was an error extracting the user id - """ - # It's not obvious why we need to pass in the redirect URI to the mapping - # provider, but we do :/ - remote_user_id = self._user_mapping_provider.get_remote_user_id( - saml2_auth, client_redirect_url - ) - - if not remote_user_id: - raise MappingException( - "Failed to extract remote user id from SAML response" - ) - - return remote_user_id - def expire_sessions(self): expire_before = self.clock.time_msec() - self._saml2_session_lifetime to_expire = set() @@ -398,6 +411,15 @@ class SamlConfig: mxid_mapper = attr.ib() +# The type definition for the user mapping provider callbacks +GET_REMOTE_USER_ID_CALLBACK = Callable[ + [saml2.response.AuthnResponse, Optional[str]], Awaitable[str] +] +SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK = Callable[ + [saml2.response.AuthnResponse, int, str], Awaitable[Dict] +] + + class DefaultSamlMappingProvider: __version__ = "0.0.1" @@ -411,12 +433,19 @@ class DefaultSamlMappingProvider: self._mxid_source_attribute = parsed_config.mxid_source_attribute self._mxid_mapper = parsed_config.mxid_mapper - self._grandfathered_mxid_source_attribute = ( - module_api._hs.config.saml2_grandfathered_mxid_source_attribute + module_api.register_saml2_user_mapping_provider_callbacks( + get_remote_user_id=self.get_remote_user_id, + saml_response_to_user_attributes=self.saml_response_to_user_attributes, + saml_attributes=( + {"uid", self._mxid_source_attribute}, + {"displayName", "email"}, + ), ) - def get_remote_user_id( - self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str + async def get_remote_user_id( + self, + saml_response: saml2.response.AuthnResponse, + client_redirect_url: Optional[str], ) -> str: """Extracts the remote user id from the SAML response""" try: @@ -425,7 +454,7 @@ class DefaultSamlMappingProvider: logger.warning("SAML2 response lacks a 'uid' attestation") raise MappingException("'uid' not in SAML2 response") - def saml_response_to_user_attributes( + async def saml_response_to_user_attributes( self, saml_response: saml2.response.AuthnResponse, failures: int, @@ -454,8 +483,8 @@ class DefaultSamlMappingProvider: "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + raise MappingException( + "%s not in SAML2 response" % (self._mxid_source_attribute,) ) # Use the configured mapper for this mxid_source @@ -501,8 +530,235 @@ class DefaultSamlMappingProvider: return SamlConfig(mxid_source_attribute, mxid_mapper) - @staticmethod - def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]: + +def load_default_or_legacy_saml2_mapping_provider(hs: "HomeServer"): + """Wrapper that loads a saml2 mapping provider either from the default module or + configured using the legacy configuration. Legacy modules then have their callbacks + registered + """ + + if hs.config.saml2.saml2_user_mapping_provider_class is None: + # This should be an impossible position to be in + raise RuntimeError("No default saml2 user mapping provider is set") + + module = hs.config.saml2.saml2_user_mapping_provider_class + config = hs.config.saml2.saml2_user_mapping_provider_config + api = hs.get_module_api() + + mapping_provider = module(config, api) + + # if we were loading the default provider, then it has already registered its callbacks! + # so we can stop here + if module == DEFAULT_USER_MAPPING_PROVIDER: + return + + # The required hooks. If a custom module doesn't implement all of these then raise an error + required_mapping_provider_methods = { + "get_saml_attributes", + "saml_response_to_user_attributes", + "get_remote_user_id", + } + missing_methods = [ + method + for method in required_mapping_provider_methods + if not hasattr(module, method) + ] + if missing_methods: + raise RuntimeError( + "Class specified by saml2_config." + " user_mapping_provider.module is missing required" + " methods: %s" % (", ".join(missing_methods),) + ) + + # New modules have to proactively register this instead of just the callback + saml_attributes = mapping_provider.get_saml_attributes(config) + + mapping_provider_methods = { + "saml_response_to_user_attributes", + "get_remote_user_id", + } + + # Methods that the module provides should be async, but this wasn't the case + # in the old module system, so we wrap them if needed + def async_wrapper(f: Callable) -> Callable[..., Awaitable]: + def run(*args, **kwargs): + return maybe_awaitable(f(*args, **kwargs)) + + return run + + # Register the hooks through the module API. + hooks = { + hook: async_wrapper(getattr(mapping_provider, hook, None)) + for hook in mapping_provider_methods + } + + api.register_saml2_user_mapping_provider_callbacks( + saml_attributes=saml_attributes, **hooks + ) + + +class Saml2UserMappingProvider: + def __init__(self, hs: "HomeServer"): + """The SAML user mapping provider + + Args: + parsed_config: Module configuration + module_api: module api proxy + """ + # self._mxid_source_attribute = parsed_config.mxid_source_attribute + # self._mxid_mapper = parsed_config.mxid_mapper + self.get_remote_user_id_callback: Optional[GET_REMOTE_USER_ID_CALLBACK] = None + self.saml_response_to_user_attributes_callback: Optional[ + SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK + ] = None + self.saml_attributes: Tuple[Set[str], Set[str]] = set(), set() + self.module_has_registered = False + + def register_saml2_user_mapping_provider_callbacks( + self, + get_remote_user_id: GET_REMOTE_USER_ID_CALLBACK, + saml_response_to_user_attributes: SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK, + saml_attributes: Tuple[Set[str], Set[str]], + ): + """Called by modules to register callbacks and saml_attributes""" + + # Only one module can register callbacks + if self.module_has_registered: + raise RuntimeError( + "Multiple modules have attempted to register as saml mapping providers" + ) + self.module_has_registered = True + + self.get_remote_user_id_callback = get_remote_user_id + self.saml_response_to_user_attributes_callback = ( + saml_response_to_user_attributes + ) + self.saml_attributes = saml_attributes + + async def get_remote_user_id( + self, + saml_response: saml2.response.AuthnResponse, + client_redirect_url: Optional[str], + ) -> str: + """Extracts the remote user id from the SAML response + + Args: + saml2_auth: The parsed SAML2 response. + client_redirect_url: The redirect URL passed in by the client. This may + be None. + Returns: + remote user id + + Raises: + MappingException: if there was an error extracting the user id + Any other exception: for backwards compatability + """ + + # If no module has registered callbacks then raise an error + if not self.module_has_registered: + raise RuntimeError("No Saml2 mapping provider has been registered") + + assert self.get_remote_user_id_callback is not None + + try: + result = await self.get_remote_user_id_callback( + saml_response, client_redirect_url + ) + except MappingException: + # Mapping providers are allowed to issue a mapping exception + # if a remote user id cannot be generated. + raise + except Exception as e: + logger.warning( + f"Something went wrong when calling custom module callback for get_remote_user_id: {e}" + ) + # for compatablity with legacy modules, need to raise this exception as is: + raise e + + # # If the module raises some other sort of exception then don't display that to the user + # raise MappingException( + # "Failed to extract remote user id from SAML response" + # ) + + if not isinstance(result, str): + logger.warning( # type: ignore[unreachable] + f"Wrong type returned by module callback for get_remote_user_id: {result}, expected str" + ) + # Don't overshare to the user, as something has clearly gone wrong + raise MappingException( + "Failed to extract remote user id from SAML response" + ) + + return result + + async def saml_response_to_user_attributes( + self, + saml_response: saml2.response.AuthnResponse, + failures: int, + client_redirect_url: str, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + client_redirect_url: where the client wants to redirect to + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. The localpart of the user's mxid + * displayname (str): The displayname of the user + * emails (list[str]): Any emails for the user + + Raises: + MappingException: if something goes wrong while processing the response + RedirectException: some mapping providers may raise this if they need + to redirect to an interstitial page. + Any other exception: for backwards compatability + """ + + # If no module has registered callbacks then raise an error + if not self.module_has_registered: + raise RuntimeError("No Saml2 mapping provider has been registered") + + assert self.saml_response_to_user_attributes_callback is not None + + try: + result = await self.saml_response_to_user_attributes_callback( + saml_response, failures, client_redirect_url + ) + except (RedirectException, MappingException): + # Mapping providers are allowed to issue a redirect (e.g. to ask + # the user for more information) and can issue a mapping exception + # if a name cannot be generated. + raise + except Exception as e: + logger.warning( + f"Something went wrong when calling custom module callback for saml_response_to_user_attributes: {e}" + ) + # for compatablity with legacy modules, need to raise this exception as is: + raise e + + # # If the module raises some other sort of exception then don't display that to the user + # raise MappingException( + # "Unable to map from SAML2 response to user attributes" + # ) + + if not isinstance(result, dict): + logger.warning( # type: ignore[unreachable] + f"Wrong type returned by module callback for get_remote_user_id: {result}, expected dict" + ) + # Don't overshare to the user, as something has clearly gone wrong + raise MappingException( + "Unable to map from SAML2 response to user attributes" + ) + + return result + + def get_saml_attributes(self) -> Tuple[Set[str], Set[str]]: """Returns the required attributes of a SAML Args: @@ -514,4 +770,4 @@ class DefaultSamlMappingProvider: second set consists of those attributes which can be used if available, but are not necessary """ - return {"uid", config.mxid_source_attribute}, {"displayName", "email"} + return self.saml_attributes diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b11fa6393b..f337a0f65b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -117,6 +117,7 @@ class ModuleApi: self._account_validity_handler = hs.get_account_validity_handler() self._third_party_event_rules = hs.get_third_party_event_rules() self._presence_router = hs.get_presence_router() + self._saml2_user_mapping_provider = hs.get_saml2_user_mapping_provider() ################################################################################# # The following methods should only be called during the module's initialisation. @@ -141,6 +142,13 @@ class ModuleApi: """Registers callbacks for presence router capabilities.""" return self._presence_router.register_presence_router_callbacks + @property + def register_saml2_user_mapping_provider_callbacks(self): + """Registers callbacks for presence router capabilities.""" + return ( + self._saml2_user_mapping_provider.register_saml2_user_mapping_provider_callbacks + ) + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py index 98ea911a81..1560cd1c36 100644 --- a/synapse/module_api/errors.py +++ b/synapse/module_api/errors.py @@ -20,3 +20,4 @@ from synapse.api.errors import ( # noqa: F401 SynapseError, ) from synapse.config._base import ConfigError # noqa: F401 +from synapse.handlers.sso import MappingException # noqa: F401 diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py index b37c7083dc..720666589b 100644 --- a/synapse/rest/synapse/client/saml2/metadata_resource.py +++ b/synapse/rest/synapse/client/saml2/metadata_resource.py @@ -25,7 +25,7 @@ class SAML2MetadataResource(Resource): def __init__(self, hs): Resource.__init__(self) - self.sp_config = hs.config.saml2_sp_config + self.sp_config = hs.get_saml_handler().saml2_sp_config def render_GET(self, request): metadata_xml = saml2.metadata.create_metadata_string( diff --git a/synapse/server.py b/synapse/server.py index de6517663e..4b44e2ea06 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -100,6 +100,7 @@ from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.room_summary import RoomSummaryHandler +from synapse.handlers.saml import Saml2UserMappingProvider from synapse.handlers.search import SearchHandler from synapse.handlers.send_email import SendEmailHandler from synapse.handlers.set_password import SetPasswordHandler @@ -730,6 +731,10 @@ class HomeServer(metaclass=abc.ABCMeta): return SamlHandler(self) @cache_in_self + def get_saml2_user_mapping_provider(self) -> "Saml2UserMappingProvider": + return Saml2UserMappingProvider(self) + + @cache_in_self def get_oidc_handler(self) -> "OidcHandler": from synapse.handlers.oidc import OidcHandler diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b69f562ca5..64c12c1d29 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -213,3 +213,31 @@ def re_word_boundary(r: str) -> str: # we can't use \b as it chokes on unicode. however \W seems to be okay # as shorthand for [^0-9A-Za-z_]. return r"(^|\W)%s(\W|$)" % (r,) + + +def dict_merge(merge_dict, into_dict): + """Do a deep merge of two dicts + + Recursively merges `merge_dict` into `into_dict`: + * For keys where both `merge_dict` and `into_dict` have a dict value, the values + are recursively merged + * For all other keys, the values in `into_dict` (if any) are overwritten with + the value from `merge_dict`. + + Args: + merge_dict (dict): dict to merge + into_dict (dict): target dict + """ + for k, v in merge_dict.items(): + if k not in into_dict: + into_dict[k] = v + continue + + current_val = into_dict[k] + + if isinstance(v, dict) and isinstance(current_val, dict): + dict_merge(v, current_val) + continue + + # otherwise we just overwrite + into_dict[k] = v |