diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 0082f85c26..107f97032b 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,6 +24,7 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.servlet import parse_string
+from synapse.module_api import ModuleApi
from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import (
UserID,
@@ -59,7 +60,8 @@ class SamlHandler:
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
- hs.config.saml2_user_mapping_provider_config
+ hs.config.saml2_user_mapping_provider_config,
+ ModuleApi(hs, hs.get_auth_handler()),
)
# identifier for the external_ids table
@@ -112,10 +114,10 @@ class SamlHandler:
# the dict.
self.expire_sessions()
- user_id = await self._map_saml_response_to_user(resp_bytes)
+ user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
- async def _map_saml_response_to_user(self, resp_bytes):
+ async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@@ -183,7 +185,7 @@ class SamlHandler:
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, i
+ saml2_auth, i, client_redirect_url=client_redirect_url,
)
logger.debug(
@@ -216,6 +218,8 @@ class SamlHandler:
500, "Unable to generate a Matrix ID from the SAML response"
)
+ logger.info("Mapped SAML user to local part %s", localpart)
+
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayname
)
@@ -265,17 +269,21 @@ class SamlConfig(object):
class DefaultSamlMappingProvider(object):
__version__ = "0.0.1"
- def __init__(self, parsed_config: SamlConfig):
+ def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
"""The default SAML user mapping provider
Args:
parsed_config: Module configuration
+ module_api: module api proxy
"""
self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper
def saml_response_to_user_attributes(
- self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+ self,
+ saml_response: saml2.response.AuthnResponse,
+ failures: int,
+ client_redirect_url: str,
) -> dict:
"""Maps some text from a SAML response to attributes of a new user
@@ -285,6 +293,8 @@ class DefaultSamlMappingProvider(object):
failures: How many times a call to this function with this
saml_response has resulted in a failure
+ client_redirect_url: where the client wants to redirect to
+
Returns:
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
|