diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index e6e71e9729..6ef2b0de02 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -13,15 +13,16 @@
# limitations under the License.
import logging
import re
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Set, Tuple
import attr
import saml2
import saml2.response
from saml2.client import Saml2Client
-from synapse.api.errors import SynapseError
+from synapse.api.errors import RedirectException
from synapse.config import ConfigError
+from synapse.config.saml2 import DEFAULT_USER_MAPPING_PROVIDER
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
@@ -32,6 +33,8 @@ from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
+from synapse.util import dict_merge
+from synapse.util.async_helpers import maybe_awaitable
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -54,7 +57,50 @@ class Saml2SessionData:
class SamlHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+
+ # If support for legacy saml2_mapping_providers is dropped then this
+ # is where the DefaultSamlMappingProvider should be loaded
+
+ self._user_mapping_provider = hs.get_saml2_user_mapping_provider()
+
+ # At this point either a module will have registered user mapping provider
+ # callbacks or the default will have been registered.
+ assert self._user_mapping_provider.module_has_registered
+
+ # Merge the required and optional saml_attributes registered by the mapping
+ # provider with the base sp config. NOTE: If there are conflicts then the
+ # module's expected attributes are overwritten by the base sp_config. This is
+ # how it worked with legacy modules.
+ (
+ required_attributes,
+ optional_attributes,
+ ) = self._user_mapping_provider.get_saml_attributes()
+
+ # Required for backwards compatability
+ if hs.config.saml2_grandfathered_mxid_source_attribute:
+ optional_attributes.add(hs.config.saml2_grandfathered_mxid_source_attribute)
+
+ optional_attributes -= required_attributes
+
+ sp_config_dict = {
+ "service": {
+ "sp": {
+ "required_attributes": list(required_attributes),
+ "optional_attributes": list(optional_attributes),
+ }
+ },
+ }
+
+ # Merged this way around for backwards compatability
+ dict_merge(
+ merge_dict=hs.config.saml2.base_sp_config,
+ into_dict=sp_config_dict,
+ )
+
+ self.saml2_sp_config = saml2.config.SPConfig()
+ self.saml2_sp_config.load(sp_config_dict)
+
+ self._saml_client = Saml2Client(self.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
@@ -64,12 +110,6 @@ class SamlHandler(BaseHandler):
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
self._error_template = hs.config.sso_error_template
- # plugin to do custom mapping from saml response to mxid
- self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
- hs.config.saml2_user_mapping_provider_config,
- ModuleApi(hs, hs.get_auth_handler()),
- )
-
# identifier for the external_ids table
self.idp_id = "saml"
@@ -222,7 +262,9 @@ class SamlHandler(BaseHandler):
# first check if we're doing a UIA
if current_session and current_session.ui_auth_session_id:
try:
- remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+ remote_user_id = await self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, None
+ )
except MappingException as e:
logger.exception("Failed to extract remote user id from SAML response")
self._sso_handler.render_error(request, "mapping_error", str(e))
@@ -273,7 +315,7 @@ class SamlHandler(BaseHandler):
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
- remote_user_id = self._remote_id_from_saml_response(
+ remote_user_id = await self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
@@ -286,7 +328,7 @@ class SamlHandler(BaseHandler):
This is backwards compatibility for abstraction for the SSO handler.
"""
# Call the mapping provider.
- result = self._user_mapping_provider.saml_response_to_user_attributes(
+ result = await self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, failures, client_redirect_url
)
# Remap some of the results.
@@ -331,35 +373,6 @@ class SamlHandler(BaseHandler):
grandfather_existing_users,
)
- def _remote_id_from_saml_response(
- self,
- saml2_auth: saml2.response.AuthnResponse,
- client_redirect_url: Optional[str],
- ) -> str:
- """Extract the unique remote id from a SAML2 AuthnResponse
-
- Args:
- saml2_auth: The parsed SAML2 response.
- client_redirect_url: The redirect URL passed in by the client.
- Returns:
- remote user id
-
- Raises:
- MappingException if there was an error extracting the user id
- """
- # It's not obvious why we need to pass in the redirect URI to the mapping
- # provider, but we do :/
- remote_user_id = self._user_mapping_provider.get_remote_user_id(
- saml2_auth, client_redirect_url
- )
-
- if not remote_user_id:
- raise MappingException(
- "Failed to extract remote user id from SAML response"
- )
-
- return remote_user_id
-
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
@@ -398,6 +411,15 @@ class SamlConfig:
mxid_mapper = attr.ib()
+# The type definition for the user mapping provider callbacks
+GET_REMOTE_USER_ID_CALLBACK = Callable[
+ [saml2.response.AuthnResponse, Optional[str]], Awaitable[str]
+]
+SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK = Callable[
+ [saml2.response.AuthnResponse, int, str], Awaitable[Dict]
+]
+
+
class DefaultSamlMappingProvider:
__version__ = "0.0.1"
@@ -411,12 +433,19 @@ class DefaultSamlMappingProvider:
self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper
- self._grandfathered_mxid_source_attribute = (
- module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+ module_api.register_saml2_user_mapping_provider_callbacks(
+ get_remote_user_id=self.get_remote_user_id,
+ saml_response_to_user_attributes=self.saml_response_to_user_attributes,
+ saml_attributes=(
+ {"uid", self._mxid_source_attribute},
+ {"displayName", "email"},
+ ),
)
- def get_remote_user_id(
- self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
+ async def get_remote_user_id(
+ self,
+ saml_response: saml2.response.AuthnResponse,
+ client_redirect_url: Optional[str],
) -> str:
"""Extracts the remote user id from the SAML response"""
try:
@@ -425,7 +454,7 @@ class DefaultSamlMappingProvider:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise MappingException("'uid' not in SAML2 response")
- def saml_response_to_user_attributes(
+ async def saml_response_to_user_attributes(
self,
saml_response: saml2.response.AuthnResponse,
failures: int,
@@ -454,8 +483,8 @@ class DefaultSamlMappingProvider:
"SAML2 response lacks a '%s' attestation",
self._mxid_source_attribute,
)
- raise SynapseError(
- 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ raise MappingException(
+ "%s not in SAML2 response" % (self._mxid_source_attribute,)
)
# Use the configured mapper for this mxid_source
@@ -501,8 +530,235 @@ class DefaultSamlMappingProvider:
return SamlConfig(mxid_source_attribute, mxid_mapper)
- @staticmethod
- def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
+
+def load_default_or_legacy_saml2_mapping_provider(hs: "HomeServer"):
+ """Wrapper that loads a saml2 mapping provider either from the default module or
+ configured using the legacy configuration. Legacy modules then have their callbacks
+ registered
+ """
+
+ if hs.config.saml2.saml2_user_mapping_provider_class is None:
+ # This should be an impossible position to be in
+ raise RuntimeError("No default saml2 user mapping provider is set")
+
+ module = hs.config.saml2.saml2_user_mapping_provider_class
+ config = hs.config.saml2.saml2_user_mapping_provider_config
+ api = hs.get_module_api()
+
+ mapping_provider = module(config, api)
+
+ # if we were loading the default provider, then it has already registered its callbacks!
+ # so we can stop here
+ if module == DEFAULT_USER_MAPPING_PROVIDER:
+ return
+
+ # The required hooks. If a custom module doesn't implement all of these then raise an error
+ required_mapping_provider_methods = {
+ "get_saml_attributes",
+ "saml_response_to_user_attributes",
+ "get_remote_user_id",
+ }
+ missing_methods = [
+ method
+ for method in required_mapping_provider_methods
+ if not hasattr(module, method)
+ ]
+ if missing_methods:
+ raise RuntimeError(
+ "Class specified by saml2_config."
+ " user_mapping_provider.module is missing required"
+ " methods: %s" % (", ".join(missing_methods),)
+ )
+
+ # New modules have to proactively register this instead of just the callback
+ saml_attributes = mapping_provider.get_saml_attributes(config)
+
+ mapping_provider_methods = {
+ "saml_response_to_user_attributes",
+ "get_remote_user_id",
+ }
+
+ # Methods that the module provides should be async, but this wasn't the case
+ # in the old module system, so we wrap them if needed
+ def async_wrapper(f: Callable) -> Callable[..., Awaitable]:
+ def run(*args, **kwargs):
+ return maybe_awaitable(f(*args, **kwargs))
+
+ return run
+
+ # Register the hooks through the module API.
+ hooks = {
+ hook: async_wrapper(getattr(mapping_provider, hook, None))
+ for hook in mapping_provider_methods
+ }
+
+ api.register_saml2_user_mapping_provider_callbacks(
+ saml_attributes=saml_attributes, **hooks
+ )
+
+
+class Saml2UserMappingProvider:
+ def __init__(self, hs: "HomeServer"):
+ """The SAML user mapping provider
+
+ Args:
+ parsed_config: Module configuration
+ module_api: module api proxy
+ """
+ # self._mxid_source_attribute = parsed_config.mxid_source_attribute
+ # self._mxid_mapper = parsed_config.mxid_mapper
+ self.get_remote_user_id_callback: Optional[GET_REMOTE_USER_ID_CALLBACK] = None
+ self.saml_response_to_user_attributes_callback: Optional[
+ SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK
+ ] = None
+ self.saml_attributes: Tuple[Set[str], Set[str]] = set(), set()
+ self.module_has_registered = False
+
+ def register_saml2_user_mapping_provider_callbacks(
+ self,
+ get_remote_user_id: GET_REMOTE_USER_ID_CALLBACK,
+ saml_response_to_user_attributes: SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK,
+ saml_attributes: Tuple[Set[str], Set[str]],
+ ):
+ """Called by modules to register callbacks and saml_attributes"""
+
+ # Only one module can register callbacks
+ if self.module_has_registered:
+ raise RuntimeError(
+ "Multiple modules have attempted to register as saml mapping providers"
+ )
+ self.module_has_registered = True
+
+ self.get_remote_user_id_callback = get_remote_user_id
+ self.saml_response_to_user_attributes_callback = (
+ saml_response_to_user_attributes
+ )
+ self.saml_attributes = saml_attributes
+
+ async def get_remote_user_id(
+ self,
+ saml_response: saml2.response.AuthnResponse,
+ client_redirect_url: Optional[str],
+ ) -> str:
+ """Extracts the remote user id from the SAML response
+
+ Args:
+ saml2_auth: The parsed SAML2 response.
+ client_redirect_url: The redirect URL passed in by the client. This may
+ be None.
+ Returns:
+ remote user id
+
+ Raises:
+ MappingException: if there was an error extracting the user id
+ Any other exception: for backwards compatability
+ """
+
+ # If no module has registered callbacks then raise an error
+ if not self.module_has_registered:
+ raise RuntimeError("No Saml2 mapping provider has been registered")
+
+ assert self.get_remote_user_id_callback is not None
+
+ try:
+ result = await self.get_remote_user_id_callback(
+ saml_response, client_redirect_url
+ )
+ except MappingException:
+ # Mapping providers are allowed to issue a mapping exception
+ # if a remote user id cannot be generated.
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Something went wrong when calling custom module callback for get_remote_user_id: {e}"
+ )
+ # for compatablity with legacy modules, need to raise this exception as is:
+ raise e
+
+ # # If the module raises some other sort of exception then don't display that to the user
+ # raise MappingException(
+ # "Failed to extract remote user id from SAML response"
+ # )
+
+ if not isinstance(result, str):
+ logger.warning( # type: ignore[unreachable]
+ f"Wrong type returned by module callback for get_remote_user_id: {result}, expected str"
+ )
+ # Don't overshare to the user, as something has clearly gone wrong
+ raise MappingException(
+ "Failed to extract remote user id from SAML response"
+ )
+
+ return result
+
+ async def saml_response_to_user_attributes(
+ self,
+ saml_response: saml2.response.AuthnResponse,
+ failures: int,
+ client_redirect_url: str,
+ ) -> dict:
+ """Maps some text from a SAML response to attributes of a new user
+
+ Args:
+ saml_response: A SAML auth response object
+
+ failures: How many times a call to this function with this
+ saml_response has resulted in a failure
+
+ client_redirect_url: where the client wants to redirect to
+
+ Returns:
+ dict: A dict containing new user attributes. Possible keys:
+ * mxid_localpart (str): Required. The localpart of the user's mxid
+ * displayname (str): The displayname of the user
+ * emails (list[str]): Any emails for the user
+
+ Raises:
+ MappingException: if something goes wrong while processing the response
+ RedirectException: some mapping providers may raise this if they need
+ to redirect to an interstitial page.
+ Any other exception: for backwards compatability
+ """
+
+ # If no module has registered callbacks then raise an error
+ if not self.module_has_registered:
+ raise RuntimeError("No Saml2 mapping provider has been registered")
+
+ assert self.saml_response_to_user_attributes_callback is not None
+
+ try:
+ result = await self.saml_response_to_user_attributes_callback(
+ saml_response, failures, client_redirect_url
+ )
+ except (RedirectException, MappingException):
+ # Mapping providers are allowed to issue a redirect (e.g. to ask
+ # the user for more information) and can issue a mapping exception
+ # if a name cannot be generated.
+ raise
+ except Exception as e:
+ logger.warning(
+ f"Something went wrong when calling custom module callback for saml_response_to_user_attributes: {e}"
+ )
+ # for compatablity with legacy modules, need to raise this exception as is:
+ raise e
+
+ # # If the module raises some other sort of exception then don't display that to the user
+ # raise MappingException(
+ # "Unable to map from SAML2 response to user attributes"
+ # )
+
+ if not isinstance(result, dict):
+ logger.warning( # type: ignore[unreachable]
+ f"Wrong type returned by module callback for get_remote_user_id: {result}, expected dict"
+ )
+ # Don't overshare to the user, as something has clearly gone wrong
+ raise MappingException(
+ "Unable to map from SAML2 response to user attributes"
+ )
+
+ return result
+
+ def get_saml_attributes(self) -> Tuple[Set[str], Set[str]]:
"""Returns the required attributes of a SAML
Args:
@@ -514,4 +770,4 @@ class DefaultSamlMappingProvider:
second set consists of those attributes which can be used if
available, but are not necessary
"""
- return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
+ return self.saml_attributes
|