summary refs log tree commit diff
path: root/synapse/handlers/saml.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/saml.py')
-rw-r--r--synapse/handlers/saml.py358
1 files changed, 307 insertions, 51 deletions
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index e6e71e9729..6ef2b0de02 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -13,15 +13,16 @@
 # limitations under the License.
 import logging
 import re
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Set, Tuple
 
 import attr
 import saml2
 import saml2.response
 from saml2.client import Saml2Client
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import RedirectException
 from synapse.config import ConfigError
+from synapse.config.saml2 import DEFAULT_USER_MAPPING_PROVIDER
 from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
@@ -32,6 +33,8 @@ from synapse.types import (
     map_username_to_mxid_localpart,
     mxid_localpart_allowed_characters,
 )
+from synapse.util import dict_merge
+from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
@@ -54,7 +57,50 @@ class Saml2SessionData:
 class SamlHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+
+        # If support for legacy saml2_mapping_providers is dropped then this
+        # is where the DefaultSamlMappingProvider should be loaded
+
+        self._user_mapping_provider = hs.get_saml2_user_mapping_provider()
+
+        # At this point either a module will have registered user mapping provider
+        # callbacks or the default will have been registered.
+        assert self._user_mapping_provider.module_has_registered
+
+        # Merge the required and optional saml_attributes registered by the mapping
+        # provider with the base sp config. NOTE: If there are conflicts then the
+        # module's expected attributes are overwritten by the base sp_config. This is
+        # how it worked with legacy modules.
+        (
+            required_attributes,
+            optional_attributes,
+        ) = self._user_mapping_provider.get_saml_attributes()
+
+        # Required for backwards compatability
+        if hs.config.saml2_grandfathered_mxid_source_attribute:
+            optional_attributes.add(hs.config.saml2_grandfathered_mxid_source_attribute)
+
+        optional_attributes -= required_attributes
+
+        sp_config_dict = {
+            "service": {
+                "sp": {
+                    "required_attributes": list(required_attributes),
+                    "optional_attributes": list(optional_attributes),
+                }
+            },
+        }
+
+        # Merged this way around for backwards compatability
+        dict_merge(
+            merge_dict=hs.config.saml2.base_sp_config,
+            into_dict=sp_config_dict,
+        )
+
+        self.saml2_sp_config = saml2.config.SPConfig()
+        self.saml2_sp_config.load(sp_config_dict)
+
+        self._saml_client = Saml2Client(self.saml2_sp_config)
         self._saml_idp_entityid = hs.config.saml2_idp_entityid
 
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
@@ -64,12 +110,6 @@ class SamlHandler(BaseHandler):
         self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
         self._error_template = hs.config.sso_error_template
 
-        # plugin to do custom mapping from saml response to mxid
-        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
-            hs.config.saml2_user_mapping_provider_config,
-            ModuleApi(hs, hs.get_auth_handler()),
-        )
-
         # identifier for the external_ids table
         self.idp_id = "saml"
 
@@ -222,7 +262,9 @@ class SamlHandler(BaseHandler):
         # first check if we're doing a UIA
         if current_session and current_session.ui_auth_session_id:
             try:
-                remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+                remote_user_id = await self._user_mapping_provider.get_remote_user_id(
+                    saml2_auth, None
+                )
             except MappingException as e:
                 logger.exception("Failed to extract remote user id from SAML response")
                 self._sso_handler.render_error(request, "mapping_error", str(e))
@@ -273,7 +315,7 @@ class SamlHandler(BaseHandler):
             RedirectException: some mapping providers may raise this if they need
                 to redirect to an interstitial page.
         """
-        remote_user_id = self._remote_id_from_saml_response(
+        remote_user_id = await self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
         )
 
@@ -286,7 +328,7 @@ class SamlHandler(BaseHandler):
             This is backwards compatibility for abstraction for the SSO handler.
             """
             # Call the mapping provider.
-            result = self._user_mapping_provider.saml_response_to_user_attributes(
+            result = await self._user_mapping_provider.saml_response_to_user_attributes(
                 saml2_auth, failures, client_redirect_url
             )
             # Remap some of the results.
@@ -331,35 +373,6 @@ class SamlHandler(BaseHandler):
             grandfather_existing_users,
         )
 
-    def _remote_id_from_saml_response(
-        self,
-        saml2_auth: saml2.response.AuthnResponse,
-        client_redirect_url: Optional[str],
-    ) -> str:
-        """Extract the unique remote id from a SAML2 AuthnResponse
-
-        Args:
-            saml2_auth: The parsed SAML2 response.
-            client_redirect_url: The redirect URL passed in by the client.
-        Returns:
-            remote user id
-
-        Raises:
-            MappingException if there was an error extracting the user id
-        """
-        # It's not obvious why we need to pass in the redirect URI to the mapping
-        # provider, but we do :/
-        remote_user_id = self._user_mapping_provider.get_remote_user_id(
-            saml2_auth, client_redirect_url
-        )
-
-        if not remote_user_id:
-            raise MappingException(
-                "Failed to extract remote user id from SAML response"
-            )
-
-        return remote_user_id
-
     def expire_sessions(self):
         expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
@@ -398,6 +411,15 @@ class SamlConfig:
     mxid_mapper = attr.ib()
 
 
+# The type definition for the user mapping provider callbacks
+GET_REMOTE_USER_ID_CALLBACK = Callable[
+    [saml2.response.AuthnResponse, Optional[str]], Awaitable[str]
+]
+SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK = Callable[
+    [saml2.response.AuthnResponse, int, str], Awaitable[Dict]
+]
+
+
 class DefaultSamlMappingProvider:
     __version__ = "0.0.1"
 
@@ -411,12 +433,19 @@ class DefaultSamlMappingProvider:
         self._mxid_source_attribute = parsed_config.mxid_source_attribute
         self._mxid_mapper = parsed_config.mxid_mapper
 
-        self._grandfathered_mxid_source_attribute = (
-            module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+        module_api.register_saml2_user_mapping_provider_callbacks(
+            get_remote_user_id=self.get_remote_user_id,
+            saml_response_to_user_attributes=self.saml_response_to_user_attributes,
+            saml_attributes=(
+                {"uid", self._mxid_source_attribute},
+                {"displayName", "email"},
+            ),
         )
 
-    def get_remote_user_id(
-        self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
+    async def get_remote_user_id(
+        self,
+        saml_response: saml2.response.AuthnResponse,
+        client_redirect_url: Optional[str],
     ) -> str:
         """Extracts the remote user id from the SAML response"""
         try:
@@ -425,7 +454,7 @@ class DefaultSamlMappingProvider:
             logger.warning("SAML2 response lacks a 'uid' attestation")
             raise MappingException("'uid' not in SAML2 response")
 
-    def saml_response_to_user_attributes(
+    async def saml_response_to_user_attributes(
         self,
         saml_response: saml2.response.AuthnResponse,
         failures: int,
@@ -454,8 +483,8 @@ class DefaultSamlMappingProvider:
                 "SAML2 response lacks a '%s' attestation",
                 self._mxid_source_attribute,
             )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            raise MappingException(
+                "%s not in SAML2 response" % (self._mxid_source_attribute,)
             )
 
         # Use the configured mapper for this mxid_source
@@ -501,8 +530,235 @@ class DefaultSamlMappingProvider:
 
         return SamlConfig(mxid_source_attribute, mxid_mapper)
 
-    @staticmethod
-    def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
+
+def load_default_or_legacy_saml2_mapping_provider(hs: "HomeServer"):
+    """Wrapper that loads a saml2 mapping provider either from the default module or
+    configured using the legacy configuration. Legacy modules then have their callbacks
+    registered
+    """
+
+    if hs.config.saml2.saml2_user_mapping_provider_class is None:
+        # This should be an impossible position to be in
+        raise RuntimeError("No default saml2 user mapping provider is set")
+
+    module = hs.config.saml2.saml2_user_mapping_provider_class
+    config = hs.config.saml2.saml2_user_mapping_provider_config
+    api = hs.get_module_api()
+
+    mapping_provider = module(config, api)
+
+    # if we were loading the default provider, then it has already registered its callbacks!
+    # so we can stop here
+    if module == DEFAULT_USER_MAPPING_PROVIDER:
+        return
+
+    # The required hooks. If a custom module doesn't implement all of these then raise an error
+    required_mapping_provider_methods = {
+        "get_saml_attributes",
+        "saml_response_to_user_attributes",
+        "get_remote_user_id",
+    }
+    missing_methods = [
+        method
+        for method in required_mapping_provider_methods
+        if not hasattr(module, method)
+    ]
+    if missing_methods:
+        raise RuntimeError(
+            "Class specified by saml2_config."
+            " user_mapping_provider.module is missing required"
+            " methods: %s" % (", ".join(missing_methods),)
+        )
+
+    # New modules have to proactively register this instead of just the callback
+    saml_attributes = mapping_provider.get_saml_attributes(config)
+
+    mapping_provider_methods = {
+        "saml_response_to_user_attributes",
+        "get_remote_user_id",
+    }
+
+    # Methods that the module provides should be async, but this wasn't the case
+    # in the old module system, so we wrap them if needed
+    def async_wrapper(f: Callable) -> Callable[..., Awaitable]:
+        def run(*args, **kwargs):
+            return maybe_awaitable(f(*args, **kwargs))
+
+        return run
+
+    # Register the hooks through the module API.
+    hooks = {
+        hook: async_wrapper(getattr(mapping_provider, hook, None))
+        for hook in mapping_provider_methods
+    }
+
+    api.register_saml2_user_mapping_provider_callbacks(
+        saml_attributes=saml_attributes, **hooks
+    )
+
+
+class Saml2UserMappingProvider:
+    def __init__(self, hs: "HomeServer"):
+        """The SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+            module_api: module api proxy
+        """
+        # self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        # self._mxid_mapper = parsed_config.mxid_mapper
+        self.get_remote_user_id_callback: Optional[GET_REMOTE_USER_ID_CALLBACK] = None
+        self.saml_response_to_user_attributes_callback: Optional[
+            SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK
+        ] = None
+        self.saml_attributes: Tuple[Set[str], Set[str]] = set(), set()
+        self.module_has_registered = False
+
+    def register_saml2_user_mapping_provider_callbacks(
+        self,
+        get_remote_user_id: GET_REMOTE_USER_ID_CALLBACK,
+        saml_response_to_user_attributes: SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK,
+        saml_attributes: Tuple[Set[str], Set[str]],
+    ):
+        """Called by modules to register callbacks and saml_attributes"""
+
+        # Only one module can register callbacks
+        if self.module_has_registered:
+            raise RuntimeError(
+                "Multiple modules have attempted to register as saml mapping providers"
+            )
+        self.module_has_registered = True
+
+        self.get_remote_user_id_callback = get_remote_user_id
+        self.saml_response_to_user_attributes_callback = (
+            saml_response_to_user_attributes
+        )
+        self.saml_attributes = saml_attributes
+
+    async def get_remote_user_id(
+        self,
+        saml_response: saml2.response.AuthnResponse,
+        client_redirect_url: Optional[str],
+    ) -> str:
+        """Extracts the remote user id from the SAML response
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client. This may
+                be None.
+        Returns:
+            remote user id
+
+        Raises:
+            MappingException: if there was an error extracting the user id
+            Any other exception: for backwards compatability
+        """
+
+        # If no module has registered callbacks then raise an error
+        if not self.module_has_registered:
+            raise RuntimeError("No Saml2 mapping provider has been registered")
+
+        assert self.get_remote_user_id_callback is not None
+
+        try:
+            result = await self.get_remote_user_id_callback(
+                saml_response, client_redirect_url
+            )
+        except MappingException:
+            # Mapping providers are allowed to issue a mapping exception
+            # if a remote user id cannot be generated.
+            raise
+        except Exception as e:
+            logger.warning(
+                f"Something went wrong when calling custom module callback for get_remote_user_id: {e}"
+            )
+            # for compatablity with legacy modules, need to raise this exception as is:
+            raise e
+
+            # # If the module raises some other sort of exception then don't display that to the user
+            # raise MappingException(
+            #     "Failed to extract remote user id from SAML response"
+            # )
+
+        if not isinstance(result, str):
+            logger.warning(  # type: ignore[unreachable]
+                f"Wrong type returned by module callback for get_remote_user_id: {result}, expected str"
+            )
+            # Don't overshare to the user, as something has clearly gone wrong
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
+
+        return result
+
+    async def saml_response_to_user_attributes(
+        self,
+        saml_response: saml2.response.AuthnResponse,
+        failures: int,
+        client_redirect_url: str,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+            client_redirect_url: where the client wants to redirect to
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+                * emails (list[str]): Any emails for the user
+
+        Raises:
+            MappingException: if something goes wrong while processing the response
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
+            Any other exception: for backwards compatability
+        """
+
+        # If no module has registered callbacks then raise an error
+        if not self.module_has_registered:
+            raise RuntimeError("No Saml2 mapping provider has been registered")
+
+        assert self.saml_response_to_user_attributes_callback is not None
+
+        try:
+            result = await self.saml_response_to_user_attributes_callback(
+                saml_response, failures, client_redirect_url
+            )
+        except (RedirectException, MappingException):
+            # Mapping providers are allowed to issue a redirect (e.g. to ask
+            # the user for more information) and can issue a mapping exception
+            # if a name cannot be generated.
+            raise
+        except Exception as e:
+            logger.warning(
+                f"Something went wrong when calling custom module callback for saml_response_to_user_attributes: {e}"
+            )
+            # for compatablity with legacy modules, need to raise this exception as is:
+            raise e
+
+            # # If the module raises some other sort of exception then don't display that to the user
+            # raise MappingException(
+            #     "Unable to map from SAML2 response to user attributes"
+            # )
+
+        if not isinstance(result, dict):
+            logger.warning(  # type: ignore[unreachable]
+                f"Wrong type returned by module callback for get_remote_user_id: {result}, expected dict"
+            )
+            # Don't overshare to the user, as something has clearly gone wrong
+            raise MappingException(
+                "Unable to map from SAML2 response to user attributes"
+            )
+
+        return result
+
+    def get_saml_attributes(self) -> Tuple[Set[str], Set[str]]:
         """Returns the required attributes of a SAML
 
         Args:
@@ -514,4 +770,4 @@ class DefaultSamlMappingProvider:
                 second set consists of those attributes which can be used if
                 available, but are not necessary
         """
-        return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
+        return self.saml_attributes