summary refs log tree commit diff
diff options
context:
space:
mode:
authorAzrenbeth <7782548+Azrenbeth@users.noreply.github.com>2021-08-24 11:33:02 +0100
committerAzrenbeth <7782548+Azrenbeth@users.noreply.github.com>2021-08-24 14:38:22 +0100
commit2b3e4e856fd00aee36423c2b4dcdb5b43d3e213c (patch)
tree395f100d902c5f3b29a090b06c9ad1e51dfd74a3
parentPersist room hierarchy pagination sessions to the database. (#10613) (diff)
downloadsynapse-2b3e4e856fd00aee36423c2b4dcdb5b43d3e213c.tar.xz
Port the saml mapping providers to new module interface
-rw-r--r--docs/sample_config.yaml18
-rw-r--r--synapse/app/_base.py6
-rw-r--r--synapse/config/saml2.py110
-rw-r--r--synapse/handlers/saml.py358
-rw-r--r--synapse/module_api/__init__.py8
-rw-r--r--synapse/module_api/errors.py1
-rw-r--r--synapse/rest/synapse/client/saml2/metadata_resource.py2
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/util/__init__.py28
9 files changed, 389 insertions, 147 deletions
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 935841dbfa..6a3c558b7e 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1544,7 +1544,9 @@ saml2_config:
   #
   # Default values will be used for the 'entityid' and 'service' settings,
   # so it is not normally necessary to specify them unless you need to
-  # override them.
+  # override them. Note that setting 'service.sp.required_attributes' or
+  # 'service.sp.optional_attributes' here will override anything configured
+  # by a module that registers saml2 user mapping provider callbacks
   #
   sp_config:
     # Point this to the IdP's metadata. You must provide either a local
@@ -1622,18 +1624,14 @@ saml2_config:
   #
   #saml_session_lifetime: 5m
 
-  # An external module can be provided here as a custom solution to
-  # mapping attributes returned from a saml provider onto a matrix user.
+  # Setting for the default mapping provider which maps attributes returned
+  # from a saml provider onto a matrix user. Custom solutions can be used by
+  # adding a module that provides these features to the 'modules' config
+  # section, in which case the following section will be ignored.
   #
   user_mapping_provider:
-    # The custom module's class. Uncomment to use a custom module.
-    #
-    #module: mapping_provider.SamlMappingProvider
-
     # Custom configuration values for the module. Below options are
-    # intended for the built-in provider, they should be changed if
-    # using a custom module. This section will be passed as a Python
-    # dictionary to the module's `parse_config` method.
+    # intended for the built-in provider.
     #
     config:
       # The SAML attribute (after mapping via the attribute maps) to use
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 39e28aff9f..4299327501 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -40,6 +40,7 @@ from synapse.crypto import context_factory
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.handlers.saml import load_default_or_legacy_saml2_mapping_provider
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -372,6 +373,11 @@ async def start(hs: "HomeServer"):
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
     load_legacy_presence_router(hs)
+    # 'module_has_registered' is true if a module calls 'register_saml2_user_mapping_provider_callbacks'
+    # Only one mapping provider can be set, so only load default (or legacy configured one) if this is
+    # still false
+    if not hs.get_saml2_user_mapping_provider().module_has_registered:
+        load_default_or_legacy_saml2_mapping_provider(hs)
 
     # If we've configured an expiry time for caches, start the background job now.
     setup_expire_lru_cache_entries(hs)
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 05e983625d..b60e945152 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -18,6 +18,7 @@ from typing import Any, List
 
 from synapse.config.sso import SsoAttributeRequirement
 from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util import dict_merge
 from synapse.util.module_loader import load_module, load_python_module
 
 from ._base import Config, ConfigError
@@ -33,34 +34,6 @@ LEGACY_USER_MAPPING_PROVIDER = (
 )
 
 
-def _dict_merge(merge_dict, into_dict):
-    """Do a deep merge of two dicts
-
-    Recursively merges `merge_dict` into `into_dict`:
-      * For keys where both `merge_dict` and `into_dict` have a dict value, the values
-        are recursively merged
-      * For all other keys, the values in `into_dict` (if any) are overwritten with
-        the value from `merge_dict`.
-
-    Args:
-        merge_dict (dict): dict to merge
-        into_dict (dict): target dict
-    """
-    for k, v in merge_dict.items():
-        if k not in into_dict:
-            into_dict[k] = v
-            continue
-
-        current_val = into_dict[k]
-
-        if isinstance(v, dict) and isinstance(current_val, dict):
-            _dict_merge(v, current_val)
-            continue
-
-        # otherwise we just overwrite
-        into_dict[k] = v
-
-
 class SAML2Config(Config):
     section = "saml2"
 
@@ -99,11 +72,15 @@ class SAML2Config(Config):
         ump_dict = saml2_config.get("user_mapping_provider") or {}
 
         # Use the default user mapping provider if not set
+        # NOTE this is the legacy way of using custom modules
+        # New style-modules should be placed in the 'modules:' config section
         ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
         if ump_dict.get("module") == LEGACY_USER_MAPPING_PROVIDER:
             ump_dict["module"] = DEFAULT_USER_MAPPING_PROVIDER
 
         # Ensure a config is present
+        # This is the config for the default mapping provider, or the legacy
+        # way of configuring a custom module
         ump_dict["config"] = ump_dict.get("config") or {}
 
         if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
@@ -132,59 +109,30 @@ class SAML2Config(Config):
             self.saml2_user_mapping_provider_config,
         ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
 
-        # Ensure loaded user mapping module has defined all necessary methods
-        # Note parse_config() is already checked during the call to load_module
-        required_methods = [
-            "get_saml_attributes",
-            "saml_response_to_user_attributes",
-            "get_remote_user_id",
-        ]
-        missing_methods = [
-            method
-            for method in required_methods
-            if not hasattr(self.saml2_user_mapping_provider_class, method)
-        ]
-        if missing_methods:
-            raise ConfigError(
-                "Class specified by saml2_config."
-                "user_mapping_provider.module is missing required "
-                "methods: %s" % (", ".join(missing_methods),)
-            )
-
-        # Get the desired saml auth response attributes from the module
-        saml2_config_dict = self._default_saml_config_dict(
-            *self.saml2_user_mapping_provider_class.get_saml_attributes(
-                self.saml2_user_mapping_provider_config
-            )
-        )
-        _dict_merge(
-            merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
+        # This is only the *base* config since a custom user mapping provider can change
+        # the values of 'service.sp.required_attributes' and 'service.sp.optional_attributes'
+        self.base_sp_config = self._default_sp_config_dict()
+        dict_merge(
+            merge_dict=saml2_config.get("sp_config", {}), into_dict=self.base_sp_config
         )
 
-        config_path = saml2_config.get("config_path", None)
-        if config_path is not None:
-            mod = load_python_module(config_path)
-            config = getattr(mod, "CONFIG", None)
-            if config is None:
+        sp_config_path = saml2_config.get("config_path", None)
+        if sp_config_path is not None:
+            mod = load_python_module(sp_config_path)
+            sp_config_from_file = getattr(mod, "CONFIG", None)
+            if sp_config_from_file is None:
                 raise ConfigError(
                     "Config path specified by saml2_config.config_path does not "
                     "have a CONFIG property."
                 )
-            _dict_merge(merge_dict=config, into_dict=saml2_config_dict)
-
-        import saml2.config
-
-        self.saml2_sp_config = saml2.config.SPConfig()
-        self.saml2_sp_config.load(saml2_config_dict)
+            dict_merge(merge_dict=sp_config_from_file, into_dict=self.base_sp_config)
 
         # session lifetime: in milliseconds
         self.saml2_session_lifetime = self.parse_duration(
             saml2_config.get("saml_session_lifetime", "15m")
         )
 
-    def _default_saml_config_dict(
-        self, required_attributes: set, optional_attributes: set
-    ):
+    def _default_sp_config_dict(self):
         """Generate a configuration dictionary with required and optional attributes that
         will be needed to process new user registration
 
@@ -203,10 +151,6 @@ class SAML2Config(Config):
         if public_baseurl is None:
             raise ConfigError("saml2_config requires a public_baseurl to be set")
 
-        if self.saml2_grandfathered_mxid_source_attribute:
-            optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
-        optional_attributes -= required_attributes
-
         metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml"
         response_url = public_baseurl + "_synapse/client/saml2/authn_response"
         return {
@@ -218,8 +162,6 @@ class SAML2Config(Config):
                             (response_url, saml2.BINDING_HTTP_POST)
                         ]
                     },
-                    "required_attributes": list(required_attributes),
-                    "optional_attributes": list(optional_attributes),
                     # "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
                 }
             },
@@ -257,7 +199,9 @@ class SAML2Config(Config):
           #
           # Default values will be used for the 'entityid' and 'service' settings,
           # so it is not normally necessary to specify them unless you need to
-          # override them.
+          # override them. Note that setting 'service.sp.required_attributes' or
+          # 'service.sp.optional_attributes' here will override anything configured
+          # by a module that registers saml2 user mapping provider callbacks
           #
           sp_config:
             # Point this to the IdP's metadata. You must provide either a local
@@ -335,18 +279,14 @@ class SAML2Config(Config):
           #
           #saml_session_lifetime: 5m
 
-          # An external module can be provided here as a custom solution to
-          # mapping attributes returned from a saml provider onto a matrix user.
+          # Setting for the default mapping provider which maps attributes returned
+          # from a saml provider onto a matrix user. Custom solutions can be used by
+          # adding a module that provides these features to the 'modules' config
+          # section, in which case the following section will be ignored.
           #
           user_mapping_provider:
-            # The custom module's class. Uncomment to use a custom module.
-            #
-            #module: mapping_provider.SamlMappingProvider
-
             # Custom configuration values for the module. Below options are
-            # intended for the built-in provider, they should be changed if
-            # using a custom module. This section will be passed as a Python
-            # dictionary to the module's `parse_config` method.
+            # intended for the built-in provider.
             #
             config:
               # The SAML attribute (after mapping via the attribute maps) to use
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index e6e71e9729..6ef2b0de02 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -13,15 +13,16 @@
 # limitations under the License.
 import logging
 import re
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional, Set, Tuple
 
 import attr
 import saml2
 import saml2.response
 from saml2.client import Saml2Client
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import RedirectException
 from synapse.config import ConfigError
+from synapse.config.saml2 import DEFAULT_USER_MAPPING_PROVIDER
 from synapse.handlers._base import BaseHandler
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
@@ -32,6 +33,8 @@ from synapse.types import (
     map_username_to_mxid_localpart,
     mxid_localpart_allowed_characters,
 )
+from synapse.util import dict_merge
+from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
@@ -54,7 +57,50 @@ class Saml2SessionData:
 class SamlHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+
+        # If support for legacy saml2_mapping_providers is dropped then this
+        # is where the DefaultSamlMappingProvider should be loaded
+
+        self._user_mapping_provider = hs.get_saml2_user_mapping_provider()
+
+        # At this point either a module will have registered user mapping provider
+        # callbacks or the default will have been registered.
+        assert self._user_mapping_provider.module_has_registered
+
+        # Merge the required and optional saml_attributes registered by the mapping
+        # provider with the base sp config. NOTE: If there are conflicts then the
+        # module's expected attributes are overwritten by the base sp_config. This is
+        # how it worked with legacy modules.
+        (
+            required_attributes,
+            optional_attributes,
+        ) = self._user_mapping_provider.get_saml_attributes()
+
+        # Required for backwards compatability
+        if hs.config.saml2_grandfathered_mxid_source_attribute:
+            optional_attributes.add(hs.config.saml2_grandfathered_mxid_source_attribute)
+
+        optional_attributes -= required_attributes
+
+        sp_config_dict = {
+            "service": {
+                "sp": {
+                    "required_attributes": list(required_attributes),
+                    "optional_attributes": list(optional_attributes),
+                }
+            },
+        }
+
+        # Merged this way around for backwards compatability
+        dict_merge(
+            merge_dict=hs.config.saml2.base_sp_config,
+            into_dict=sp_config_dict,
+        )
+
+        self.saml2_sp_config = saml2.config.SPConfig()
+        self.saml2_sp_config.load(sp_config_dict)
+
+        self._saml_client = Saml2Client(self.saml2_sp_config)
         self._saml_idp_entityid = hs.config.saml2_idp_entityid
 
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
@@ -64,12 +110,6 @@ class SamlHandler(BaseHandler):
         self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
         self._error_template = hs.config.sso_error_template
 
-        # plugin to do custom mapping from saml response to mxid
-        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
-            hs.config.saml2_user_mapping_provider_config,
-            ModuleApi(hs, hs.get_auth_handler()),
-        )
-
         # identifier for the external_ids table
         self.idp_id = "saml"
 
@@ -222,7 +262,9 @@ class SamlHandler(BaseHandler):
         # first check if we're doing a UIA
         if current_session and current_session.ui_auth_session_id:
             try:
-                remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+                remote_user_id = await self._user_mapping_provider.get_remote_user_id(
+                    saml2_auth, None
+                )
             except MappingException as e:
                 logger.exception("Failed to extract remote user id from SAML response")
                 self._sso_handler.render_error(request, "mapping_error", str(e))
@@ -273,7 +315,7 @@ class SamlHandler(BaseHandler):
             RedirectException: some mapping providers may raise this if they need
                 to redirect to an interstitial page.
         """
-        remote_user_id = self._remote_id_from_saml_response(
+        remote_user_id = await self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
         )
 
@@ -286,7 +328,7 @@ class SamlHandler(BaseHandler):
             This is backwards compatibility for abstraction for the SSO handler.
             """
             # Call the mapping provider.
-            result = self._user_mapping_provider.saml_response_to_user_attributes(
+            result = await self._user_mapping_provider.saml_response_to_user_attributes(
                 saml2_auth, failures, client_redirect_url
             )
             # Remap some of the results.
@@ -331,35 +373,6 @@ class SamlHandler(BaseHandler):
             grandfather_existing_users,
         )
 
-    def _remote_id_from_saml_response(
-        self,
-        saml2_auth: saml2.response.AuthnResponse,
-        client_redirect_url: Optional[str],
-    ) -> str:
-        """Extract the unique remote id from a SAML2 AuthnResponse
-
-        Args:
-            saml2_auth: The parsed SAML2 response.
-            client_redirect_url: The redirect URL passed in by the client.
-        Returns:
-            remote user id
-
-        Raises:
-            MappingException if there was an error extracting the user id
-        """
-        # It's not obvious why we need to pass in the redirect URI to the mapping
-        # provider, but we do :/
-        remote_user_id = self._user_mapping_provider.get_remote_user_id(
-            saml2_auth, client_redirect_url
-        )
-
-        if not remote_user_id:
-            raise MappingException(
-                "Failed to extract remote user id from SAML response"
-            )
-
-        return remote_user_id
-
     def expire_sessions(self):
         expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
@@ -398,6 +411,15 @@ class SamlConfig:
     mxid_mapper = attr.ib()
 
 
+# The type definition for the user mapping provider callbacks
+GET_REMOTE_USER_ID_CALLBACK = Callable[
+    [saml2.response.AuthnResponse, Optional[str]], Awaitable[str]
+]
+SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK = Callable[
+    [saml2.response.AuthnResponse, int, str], Awaitable[Dict]
+]
+
+
 class DefaultSamlMappingProvider:
     __version__ = "0.0.1"
 
@@ -411,12 +433,19 @@ class DefaultSamlMappingProvider:
         self._mxid_source_attribute = parsed_config.mxid_source_attribute
         self._mxid_mapper = parsed_config.mxid_mapper
 
-        self._grandfathered_mxid_source_attribute = (
-            module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+        module_api.register_saml2_user_mapping_provider_callbacks(
+            get_remote_user_id=self.get_remote_user_id,
+            saml_response_to_user_attributes=self.saml_response_to_user_attributes,
+            saml_attributes=(
+                {"uid", self._mxid_source_attribute},
+                {"displayName", "email"},
+            ),
         )
 
-    def get_remote_user_id(
-        self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
+    async def get_remote_user_id(
+        self,
+        saml_response: saml2.response.AuthnResponse,
+        client_redirect_url: Optional[str],
     ) -> str:
         """Extracts the remote user id from the SAML response"""
         try:
@@ -425,7 +454,7 @@ class DefaultSamlMappingProvider:
             logger.warning("SAML2 response lacks a 'uid' attestation")
             raise MappingException("'uid' not in SAML2 response")
 
-    def saml_response_to_user_attributes(
+    async def saml_response_to_user_attributes(
         self,
         saml_response: saml2.response.AuthnResponse,
         failures: int,
@@ -454,8 +483,8 @@ class DefaultSamlMappingProvider:
                 "SAML2 response lacks a '%s' attestation",
                 self._mxid_source_attribute,
             )
-            raise SynapseError(
-                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+            raise MappingException(
+                "%s not in SAML2 response" % (self._mxid_source_attribute,)
             )
 
         # Use the configured mapper for this mxid_source
@@ -501,8 +530,235 @@ class DefaultSamlMappingProvider:
 
         return SamlConfig(mxid_source_attribute, mxid_mapper)
 
-    @staticmethod
-    def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
+
+def load_default_or_legacy_saml2_mapping_provider(hs: "HomeServer"):
+    """Wrapper that loads a saml2 mapping provider either from the default module or
+    configured using the legacy configuration. Legacy modules then have their callbacks
+    registered
+    """
+
+    if hs.config.saml2.saml2_user_mapping_provider_class is None:
+        # This should be an impossible position to be in
+        raise RuntimeError("No default saml2 user mapping provider is set")
+
+    module = hs.config.saml2.saml2_user_mapping_provider_class
+    config = hs.config.saml2.saml2_user_mapping_provider_config
+    api = hs.get_module_api()
+
+    mapping_provider = module(config, api)
+
+    # if we were loading the default provider, then it has already registered its callbacks!
+    # so we can stop here
+    if module == DEFAULT_USER_MAPPING_PROVIDER:
+        return
+
+    # The required hooks. If a custom module doesn't implement all of these then raise an error
+    required_mapping_provider_methods = {
+        "get_saml_attributes",
+        "saml_response_to_user_attributes",
+        "get_remote_user_id",
+    }
+    missing_methods = [
+        method
+        for method in required_mapping_provider_methods
+        if not hasattr(module, method)
+    ]
+    if missing_methods:
+        raise RuntimeError(
+            "Class specified by saml2_config."
+            " user_mapping_provider.module is missing required"
+            " methods: %s" % (", ".join(missing_methods),)
+        )
+
+    # New modules have to proactively register this instead of just the callback
+    saml_attributes = mapping_provider.get_saml_attributes(config)
+
+    mapping_provider_methods = {
+        "saml_response_to_user_attributes",
+        "get_remote_user_id",
+    }
+
+    # Methods that the module provides should be async, but this wasn't the case
+    # in the old module system, so we wrap them if needed
+    def async_wrapper(f: Callable) -> Callable[..., Awaitable]:
+        def run(*args, **kwargs):
+            return maybe_awaitable(f(*args, **kwargs))
+
+        return run
+
+    # Register the hooks through the module API.
+    hooks = {
+        hook: async_wrapper(getattr(mapping_provider, hook, None))
+        for hook in mapping_provider_methods
+    }
+
+    api.register_saml2_user_mapping_provider_callbacks(
+        saml_attributes=saml_attributes, **hooks
+    )
+
+
+class Saml2UserMappingProvider:
+    def __init__(self, hs: "HomeServer"):
+        """The SAML user mapping provider
+
+        Args:
+            parsed_config: Module configuration
+            module_api: module api proxy
+        """
+        # self._mxid_source_attribute = parsed_config.mxid_source_attribute
+        # self._mxid_mapper = parsed_config.mxid_mapper
+        self.get_remote_user_id_callback: Optional[GET_REMOTE_USER_ID_CALLBACK] = None
+        self.saml_response_to_user_attributes_callback: Optional[
+            SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK
+        ] = None
+        self.saml_attributes: Tuple[Set[str], Set[str]] = set(), set()
+        self.module_has_registered = False
+
+    def register_saml2_user_mapping_provider_callbacks(
+        self,
+        get_remote_user_id: GET_REMOTE_USER_ID_CALLBACK,
+        saml_response_to_user_attributes: SAML_RESPONSE_TO_USER_ATTRIBUTES_CALLBACK,
+        saml_attributes: Tuple[Set[str], Set[str]],
+    ):
+        """Called by modules to register callbacks and saml_attributes"""
+
+        # Only one module can register callbacks
+        if self.module_has_registered:
+            raise RuntimeError(
+                "Multiple modules have attempted to register as saml mapping providers"
+            )
+        self.module_has_registered = True
+
+        self.get_remote_user_id_callback = get_remote_user_id
+        self.saml_response_to_user_attributes_callback = (
+            saml_response_to_user_attributes
+        )
+        self.saml_attributes = saml_attributes
+
+    async def get_remote_user_id(
+        self,
+        saml_response: saml2.response.AuthnResponse,
+        client_redirect_url: Optional[str],
+    ) -> str:
+        """Extracts the remote user id from the SAML response
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client. This may
+                be None.
+        Returns:
+            remote user id
+
+        Raises:
+            MappingException: if there was an error extracting the user id
+            Any other exception: for backwards compatability
+        """
+
+        # If no module has registered callbacks then raise an error
+        if not self.module_has_registered:
+            raise RuntimeError("No Saml2 mapping provider has been registered")
+
+        assert self.get_remote_user_id_callback is not None
+
+        try:
+            result = await self.get_remote_user_id_callback(
+                saml_response, client_redirect_url
+            )
+        except MappingException:
+            # Mapping providers are allowed to issue a mapping exception
+            # if a remote user id cannot be generated.
+            raise
+        except Exception as e:
+            logger.warning(
+                f"Something went wrong when calling custom module callback for get_remote_user_id: {e}"
+            )
+            # for compatablity with legacy modules, need to raise this exception as is:
+            raise e
+
+            # # If the module raises some other sort of exception then don't display that to the user
+            # raise MappingException(
+            #     "Failed to extract remote user id from SAML response"
+            # )
+
+        if not isinstance(result, str):
+            logger.warning(  # type: ignore[unreachable]
+                f"Wrong type returned by module callback for get_remote_user_id: {result}, expected str"
+            )
+            # Don't overshare to the user, as something has clearly gone wrong
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
+
+        return result
+
+    async def saml_response_to_user_attributes(
+        self,
+        saml_response: saml2.response.AuthnResponse,
+        failures: int,
+        client_redirect_url: str,
+    ) -> dict:
+        """Maps some text from a SAML response to attributes of a new user
+
+        Args:
+            saml_response: A SAML auth response object
+
+            failures: How many times a call to this function with this
+                saml_response has resulted in a failure
+
+            client_redirect_url: where the client wants to redirect to
+
+        Returns:
+            dict: A dict containing new user attributes. Possible keys:
+                * mxid_localpart (str): Required. The localpart of the user's mxid
+                * displayname (str): The displayname of the user
+                * emails (list[str]): Any emails for the user
+
+        Raises:
+            MappingException: if something goes wrong while processing the response
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
+            Any other exception: for backwards compatability
+        """
+
+        # If no module has registered callbacks then raise an error
+        if not self.module_has_registered:
+            raise RuntimeError("No Saml2 mapping provider has been registered")
+
+        assert self.saml_response_to_user_attributes_callback is not None
+
+        try:
+            result = await self.saml_response_to_user_attributes_callback(
+                saml_response, failures, client_redirect_url
+            )
+        except (RedirectException, MappingException):
+            # Mapping providers are allowed to issue a redirect (e.g. to ask
+            # the user for more information) and can issue a mapping exception
+            # if a name cannot be generated.
+            raise
+        except Exception as e:
+            logger.warning(
+                f"Something went wrong when calling custom module callback for saml_response_to_user_attributes: {e}"
+            )
+            # for compatablity with legacy modules, need to raise this exception as is:
+            raise e
+
+            # # If the module raises some other sort of exception then don't display that to the user
+            # raise MappingException(
+            #     "Unable to map from SAML2 response to user attributes"
+            # )
+
+        if not isinstance(result, dict):
+            logger.warning(  # type: ignore[unreachable]
+                f"Wrong type returned by module callback for get_remote_user_id: {result}, expected dict"
+            )
+            # Don't overshare to the user, as something has clearly gone wrong
+            raise MappingException(
+                "Unable to map from SAML2 response to user attributes"
+            )
+
+        return result
+
+    def get_saml_attributes(self) -> Tuple[Set[str], Set[str]]:
         """Returns the required attributes of a SAML
 
         Args:
@@ -514,4 +770,4 @@ class DefaultSamlMappingProvider:
                 second set consists of those attributes which can be used if
                 available, but are not necessary
         """
-        return {"uid", config.mxid_source_attribute}, {"displayName", "email"}
+        return self.saml_attributes
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index b11fa6393b..f337a0f65b 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -117,6 +117,7 @@ class ModuleApi:
         self._account_validity_handler = hs.get_account_validity_handler()
         self._third_party_event_rules = hs.get_third_party_event_rules()
         self._presence_router = hs.get_presence_router()
+        self._saml2_user_mapping_provider = hs.get_saml2_user_mapping_provider()
 
     #################################################################################
     # The following methods should only be called during the module's initialisation.
@@ -141,6 +142,13 @@ class ModuleApi:
         """Registers callbacks for presence router capabilities."""
         return self._presence_router.register_presence_router_callbacks
 
+    @property
+    def register_saml2_user_mapping_provider_callbacks(self):
+        """Registers callbacks for presence router capabilities."""
+        return (
+            self._saml2_user_mapping_provider.register_saml2_user_mapping_provider_callbacks
+        )
+
     def register_web_resource(self, path: str, resource: IResource):
         """Registers a web resource to be served at the given path.
 
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index 98ea911a81..1560cd1c36 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -20,3 +20,4 @@ from synapse.api.errors import (  # noqa: F401
     SynapseError,
 )
 from synapse.config._base import ConfigError  # noqa: F401
+from synapse.handlers.sso import MappingException  # noqa: F401
diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index b37c7083dc..720666589b 100644
--- a/synapse/rest/synapse/client/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
@@ -25,7 +25,7 @@ class SAML2MetadataResource(Resource):
 
     def __init__(self, hs):
         Resource.__init__(self)
-        self.sp_config = hs.config.saml2_sp_config
+        self.sp_config = hs.get_saml_handler().saml2_sp_config
 
     def render_GET(self, request):
         metadata_xml = saml2.metadata.create_metadata_string(
diff --git a/synapse/server.py b/synapse/server.py
index de6517663e..4b44e2ea06 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -100,6 +100,7 @@ from synapse.handlers.room_list import RoomListHandler
 from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
 from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
 from synapse.handlers.room_summary import RoomSummaryHandler
+from synapse.handlers.saml import Saml2UserMappingProvider
 from synapse.handlers.search import SearchHandler
 from synapse.handlers.send_email import SendEmailHandler
 from synapse.handlers.set_password import SetPasswordHandler
@@ -730,6 +731,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return SamlHandler(self)
 
     @cache_in_self
+    def get_saml2_user_mapping_provider(self) -> "Saml2UserMappingProvider":
+        return Saml2UserMappingProvider(self)
+
+    @cache_in_self
     def get_oidc_handler(self) -> "OidcHandler":
         from synapse.handlers.oidc import OidcHandler
 
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index b69f562ca5..64c12c1d29 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -213,3 +213,31 @@ def re_word_boundary(r: str) -> str:
     # we can't use \b as it chokes on unicode. however \W seems to be okay
     # as shorthand for [^0-9A-Za-z_].
     return r"(^|\W)%s(\W|$)" % (r,)
+
+
+def dict_merge(merge_dict, into_dict):
+    """Do a deep merge of two dicts
+
+    Recursively merges `merge_dict` into `into_dict`:
+      * For keys where both `merge_dict` and `into_dict` have a dict value, the values
+        are recursively merged
+      * For all other keys, the values in `into_dict` (if any) are overwritten with
+        the value from `merge_dict`.
+
+    Args:
+        merge_dict (dict): dict to merge
+        into_dict (dict): target dict
+    """
+    for k, v in merge_dict.items():
+        if k not in into_dict:
+            into_dict[k] = v
+            continue
+
+        current_val = into_dict[k]
+
+        if isinstance(v, dict) and isinstance(current_val, dict):
+            dict_merge(v, current_val)
+            continue
+
+        # otherwise we just overwrite
+        into_dict[k] = v