diff options
Diffstat (limited to 'synapse/handlers/saml.py')
-rw-r--r-- | synapse/handlers/saml.py | 358 |
1 files changed, 307 insertions, 51 deletions
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index e6e71e9729..6ef2b0de02 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -13,15 +13,16 @@ # limitations under the License. import logging import re -from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Set, Tuple import attr import saml2 import saml2.response from saml2.client import Saml2Client -from synapse.api.errors import SynapseError +from synapse.api.errors import RedirectException from synapse.config import ConfigError +from synapse.config.saml2 import DEFAULT_USER_MAPPING_PROVIDER from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string @@ -32,6 +33,8 @@ from synapse.types import ( map_username_to_mxid_localpart, mxid_localpart_allowed_characters, ) +from synapse.util import dict_merge +from synapse.util.async_helpers import maybe_awaitable from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: @@ -54,7 +57,50 @@ class Saml2SessionData: class SamlHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self._saml_client = Saml2Client(hs.config.saml2_sp_config) + + # If support for legacy saml2_mapping_providers is dropped then this + # is where the DefaultSamlMappingProvider should be loaded + + self._user_mapping_provider = hs.get_saml2_user_mapping_provider() + + # At this point either a module will have registered user mapping provider + # callbacks or the default will have been registered. + assert self._user_mapping_provider.module_has_registered + + # Merge the required and optional saml_attributes registered by the mapping + # provider with the base sp config. NOTE: If there are conflicts then the + # module's expected attributes are overwritten by the base sp_config. This is + # how it worked with legacy modules. + ( + required_attributes, + optional_attributes, + ) = self._user_mapping_provider.get_saml_attributes() + + # Required for backwards compatability + if hs.config.saml2_grandfathered_mxid_source_attribute: + optional_attributes.add(hs.config.saml2_grandfathered_mxid_source_attribute) + + optional_attributes -= required_attributes + + sp_config_dict = { + "service": { + "sp": { + "required_attributes": list(required_attributes), + "optional_attributes": list(optional_attributes), + } + }, + } + + # Merged this way around for backwards compatability + dict_merge( + merge_dict=hs.config.saml2.base_sp_config, + into_dict=sp_config_dict, + ) + + self.saml2_sp_config = saml2.config.SPConfig() + self.saml2_sp_config.load(sp_config_dict) + + self._saml_client = Saml2Client(self.saml2_sp_config) self._saml_idp_entityid = hs.config.saml2_idp_entityid self._saml2_session_lifetime = hs.config.saml2_session_lifetime @@ -64,12 +110,6 @@ class SamlHandler(BaseHandler): self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements self._error_template = hs.config.sso_error_template - # plugin to do custom mapping from saml response to mxid - self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( - hs.config.saml2_user_mapping_provider_config, - ModuleApi(hs, hs.get_auth_handler()), - ) - # identifier for the external_ids table self.idp_id = "saml" @@ -222,7 +262,9 @@ class SamlHandler(BaseHandler): # first check if we're doing a UIA if current_session and current_session.ui_auth_session_id: try: - remote_user_id = self._remote_id_from_saml_response(saml2_auth, None) + remote_user_id = await self._user_mapping_provider.get_remote_user_id( + saml2_auth, None + ) except MappingException as e: logger.exception("Failed to extract remote user id from SAML response") self._sso_handler.render_error(request, "mapping_error", str(e)) @@ -273,7 +315,7 @@ class SamlHandler(BaseHandler): RedirectException: some mapping providers may raise this if they need to redirect to an interstitial page. """ - remote_user_id = self._remote_id_from_saml_response( + remote_user_id = await self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url ) @@ -286,7 +328,7 @@ class SamlHandler(BaseHandler): This is backwards compatibility for abstraction for the SSO handler. """ # Call the mapping provider. - result = self._user_mapping_provider.saml_response_to_user_attributes( + result = await self._user_mapping_provider.saml_response_to_user_attributes( saml2_auth, failures, client_redirect_url ) # Remap some of the results. @@ -331,35 +373,6 @@ class SamlHandler(BaseHandler): grandfather_existing_users, ) - def _remote_id_from_saml_response( - self, - saml2_auth: saml2.response.AuthnResponse, - client_redirect_url: Optional[str], - ) -> str: - """Extract the unique remote id from a SAML2 AuthnResponse - - Args: - saml2_auth: The parsed SAML2 response. - client_redirect_url: The redirect URL passed in by the client. - Returns: - remote user id - - Raises: - MappingException if there was an error extracting the user id - """ - # It's not obvious why we need to pass in the redirect URI to the mapping - # provider, but we do :/ - remote_user_id = self._user_mapping_provider.get_remote_user_id( - saml2_auth, client_redirect_url - ) - - if not remote_user_id: - raise MappingException( - "Failed to extract remote user id from SAML response" - ) - - return remote_user_id - def expire_sessions(self): expire_before = self.clock.time_msec() - self._saml2_session_lifetime to_expire = set() @@ -398,6 +411,15 @@ class SamlConfig: mxid_mapper = attr.ib() +# The type definition for the user mapping provider callbacks +GET_REMOTE_USER_ID_CALLBACK = Callable[ + [saml2.response.AuthnResponse, Optional[str]], Awaitable[str] +] +SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK = Callable[ + [saml2.response.AuthnResponse, int, str], Awaitable[Dict] +] + + class DefaultSamlMappingProvider: __version__ = "0.0.1" @@ -411,12 +433,19 @@ class DefaultSamlMappingProvider: self._mxid_source_attribute = parsed_config.mxid_source_attribute self._mxid_mapper = parsed_config.mxid_mapper - self._grandfathered_mxid_source_attribute = ( - module_api._hs.config.saml2_grandfathered_mxid_source_attribute + module_api.register_saml2_user_mapping_provider_callbacks( + get_remote_user_id=self.get_remote_user_id, + saml_response_to_user_attributes=self.saml_response_to_user_attributes, + saml_attributes=( + {"uid", self._mxid_source_attribute}, + {"displayName", "email"}, + ), ) - def get_remote_user_id( - self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str + async def get_remote_user_id( + self, + saml_response: saml2.response.AuthnResponse, + client_redirect_url: Optional[str], ) -> str: """Extracts the remote user id from the SAML response""" try: @@ -425,7 +454,7 @@ class DefaultSamlMappingProvider: logger.warning("SAML2 response lacks a 'uid' attestation") raise MappingException("'uid' not in SAML2 response") - def saml_response_to_user_attributes( + async def saml_response_to_user_attributes( self, saml_response: saml2.response.AuthnResponse, failures: int, @@ -454,8 +483,8 @@ class DefaultSamlMappingProvider: "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + raise MappingException( + "%s not in SAML2 response" % (self._mxid_source_attribute,) ) # Use the configured mapper for this mxid_source @@ -501,8 +530,235 @@ class DefaultSamlMappingProvider: return SamlConfig(mxid_source_attribute, mxid_mapper) - @staticmethod - def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]: + +def load_default_or_legacy_saml2_mapping_provider(hs: "HomeServer"): + """Wrapper that loads a saml2 mapping provider either from the default module or + configured using the legacy configuration. Legacy modules then have their callbacks + registered + """ + + if hs.config.saml2.saml2_user_mapping_provider_class is None: + # This should be an impossible position to be in + raise RuntimeError("No default saml2 user mapping provider is set") + + module = hs.config.saml2.saml2_user_mapping_provider_class + config = hs.config.saml2.saml2_user_mapping_provider_config + api = hs.get_module_api() + + mapping_provider = module(config, api) + + # if we were loading the default provider, then it has already registered its callbacks! + # so we can stop here + if module == DEFAULT_USER_MAPPING_PROVIDER: + return + + # The required hooks. If a custom module doesn't implement all of these then raise an error + required_mapping_provider_methods = { + "get_saml_attributes", + "saml_response_to_user_attributes", + "get_remote_user_id", + } + missing_methods = [ + method + for method in required_mapping_provider_methods + if not hasattr(module, method) + ] + if missing_methods: + raise RuntimeError( + "Class specified by saml2_config." + " user_mapping_provider.module is missing required" + " methods: %s" % (", ".join(missing_methods),) + ) + + # New modules have to proactively register this instead of just the callback + saml_attributes = mapping_provider.get_saml_attributes(config) + + mapping_provider_methods = { + "saml_response_to_user_attributes", + "get_remote_user_id", + } + + # Methods that the module provides should be async, but this wasn't the case + # in the old module system, so we wrap them if needed + def async_wrapper(f: Callable) -> Callable[..., Awaitable]: + def run(*args, **kwargs): + return maybe_awaitable(f(*args, **kwargs)) + + return run + + # Register the hooks through the module API. + hooks = { + hook: async_wrapper(getattr(mapping_provider, hook, None)) + for hook in mapping_provider_methods + } + + api.register_saml2_user_mapping_provider_callbacks( + saml_attributes=saml_attributes, **hooks + ) + + +class Saml2UserMappingProvider: + def __init__(self, hs: "HomeServer"): + """The SAML user mapping provider + + Args: + parsed_config: Module configuration + module_api: module api proxy + """ + # self._mxid_source_attribute = parsed_config.mxid_source_attribute + # self._mxid_mapper = parsed_config.mxid_mapper + self.get_remote_user_id_callback: Optional[GET_REMOTE_USER_ID_CALLBACK] = None + self.saml_response_to_user_attributes_callback: Optional[ + SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK + ] = None + self.saml_attributes: Tuple[Set[str], Set[str]] = set(), set() + self.module_has_registered = False + + def register_saml2_user_mapping_provider_callbacks( + self, + get_remote_user_id: GET_REMOTE_USER_ID_CALLBACK, + saml_response_to_user_attributes: SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK, + saml_attributes: Tuple[Set[str], Set[str]], + ): + """Called by modules to register callbacks and saml_attributes""" + + # Only one module can register callbacks + if self.module_has_registered: + raise RuntimeError( + "Multiple modules have attempted to register as saml mapping providers" + ) + self.module_has_registered = True + + self.get_remote_user_id_callback = get_remote_user_id + self.saml_response_to_user_attributes_callback = ( + saml_response_to_user_attributes + ) + self.saml_attributes = saml_attributes + + async def get_remote_user_id( + self, + saml_response: saml2.response.AuthnResponse, + client_redirect_url: Optional[str], + ) -> str: + """Extracts the remote user id from the SAML response + + Args: + saml2_auth: The parsed SAML2 response. + client_redirect_url: The redirect URL passed in by the client. This may + be None. + Returns: + remote user id + + Raises: + MappingException: if there was an error extracting the user id + Any other exception: for backwards compatability + """ + + # If no module has registered callbacks then raise an error + if not self.module_has_registered: + raise RuntimeError("No Saml2 mapping provider has been registered") + + assert self.get_remote_user_id_callback is not None + + try: + result = await self.get_remote_user_id_callback( + saml_response, client_redirect_url + ) + except MappingException: + # Mapping providers are allowed to issue a mapping exception + # if a remote user id cannot be generated. + raise + except Exception as e: + logger.warning( + f"Something went wrong when calling custom module callback for get_remote_user_id: {e}" + ) + # for compatablity with legacy modules, need to raise this exception as is: + raise e + + # # If the module raises some other sort of exception then don't display that to the user + # raise MappingException( + # "Failed to extract remote user id from SAML response" + # ) + + if not isinstance(result, str): + logger.warning( # type: ignore[unreachable] + f"Wrong type returned by module callback for get_remote_user_id: {result}, expected str" + ) + # Don't overshare to the user, as something has clearly gone wrong + raise MappingException( + "Failed to extract remote user id from SAML response" + ) + + return result + + async def saml_response_to_user_attributes( + self, + saml_response: saml2.response.AuthnResponse, + failures: int, + client_redirect_url: str, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + client_redirect_url: where the client wants to redirect to + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. The localpart of the user's mxid + * displayname (str): The displayname of the user + * emails (list[str]): Any emails for the user + + Raises: + MappingException: if something goes wrong while processing the response + RedirectException: some mapping providers may raise this if they need + to redirect to an interstitial page. + Any other exception: for backwards compatability + """ + + # If no module has registered callbacks then raise an error + if not self.module_has_registered: + raise RuntimeError("No Saml2 mapping provider has been registered") + + assert self.saml_response_to_user_attributes_callback is not None + + try: + result = await self.saml_response_to_user_attributes_callback( + saml_response, failures, client_redirect_url + ) + except (RedirectException, MappingException): + # Mapping providers are allowed to issue a redirect (e.g. to ask + # the user for more information) and can issue a mapping exception + # if a name cannot be generated. + raise + except Exception as e: + logger.warning( + f"Something went wrong when calling custom module callback for saml_response_to_user_attributes: {e}" + ) + # for compatablity with legacy modules, need to raise this exception as is: + raise e + + # # If the module raises some other sort of exception then don't display that to the user + # raise MappingException( + # "Unable to map from SAML2 response to user attributes" + # ) + + if not isinstance(result, dict): + logger.warning( # type: ignore[unreachable] + f"Wrong type returned by module callback for get_remote_user_id: {result}, expected dict" + ) + # Don't overshare to the user, as something has clearly gone wrong + raise MappingException( + "Unable to map from SAML2 response to user attributes" + ) + + return result + + def get_saml_attributes(self) -> Tuple[Set[str], Set[str]]: """Returns the required attributes of a SAML Args: @@ -514,4 +770,4 @@ class DefaultSamlMappingProvider: second set consists of those attributes which can be used if available, but are not necessary """ - return {"uid", config.mxid_source_attribute}, {"displayName", "email"} + return self.saml_attributes |