summary refs log tree commit diff
path: root/synapse/handlers/saml_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/saml_handler.py')
-rw-r--r--synapse/handlers/saml_handler.py62
1 files changed, 49 insertions, 13 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 0082f85c26..7f411b53b9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,6 +24,7 @@ from saml2.client import Saml2Client
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.http.servlet import parse_string
+from synapse.module_api import ModuleApi
 from synapse.rest.client.v1.login import SSOAuthHandler
 from synapse.types import (
     UserID,
@@ -31,6 +32,7 @@ from synapse.types import (
     mxid_localpart_allowed_characters,
 )
 from synapse.util.async_helpers import Linearizer
+from synapse.util.iterutils import chunk_seq
 
 logger = logging.getLogger(__name__)
 
@@ -59,7 +61,8 @@ class SamlHandler:
 
         # plugin to do custom mapping from saml response to mxid
         self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
-            hs.config.saml2_user_mapping_provider_config
+            hs.config.saml2_user_mapping_provider_config,
+            ModuleApi(hs, hs.get_auth_handler()),
         )
 
         # identifier for the external_ids table
@@ -112,10 +115,10 @@ class SamlHandler:
         # the dict.
         self.expire_sessions()
 
-        user_id = await self._map_saml_response_to_user(resp_bytes)
+        user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
         self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
 
-    async def _map_saml_response_to_user(self, resp_bytes):
+    async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -130,17 +133,28 @@ class SamlHandler:
             logger.warning("SAML2 response was not signed")
             raise SynapseError(400, "SAML2 response was not signed")
 
-        logger.info("SAML2 response: %s", saml2_auth.origxml)
-        logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
+        logger.debug("SAML2 response: %s", saml2_auth.origxml)
+        for assertion in saml2_auth.assertions:
+            # kibana limits the length of a log field, whereas this is all rather
+            # useful, so split it up.
+            count = 0
+            for part in chunk_seq(str(assertion), 10000):
+                logger.info(
+                    "SAML2 assertion: %s%s", "(%i)..." % (count,) if count else "", part
+                )
+                count += 1
 
-        try:
-            remote_user_id = saml2_auth.ava["uid"][0]
-        except KeyError:
-            logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "'uid' not in SAML2 response")
+        logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+            saml2_auth, client_redirect_url
+        )
+
+        if not remote_user_id:
+            raise Exception("Failed to extract remote user id from SAML response")
+
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -183,7 +197,7 @@ class SamlHandler:
             # Map saml response to user attributes using the configured mapping provider
             for i in range(1000):
                 attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
-                    saml2_auth, i
+                    saml2_auth, i, client_redirect_url=client_redirect_url,
                 )
 
                 logger.debug(
@@ -216,6 +230,8 @@ class SamlHandler:
                     500, "Unable to generate a Matrix ID from the SAML response"
                 )
 
+            logger.info("Mapped SAML user to local part %s", localpart)
+
             registered_user_id = await self._registration_handler.register_user(
                 localpart=localpart, default_display_name=displayname
             )
@@ -265,17 +281,35 @@ class SamlConfig(object):
 class DefaultSamlMappingProvider(object):
     __version__ = "0.0.1"
 
-    def __init__(self, parsed_config: SamlConfig):
+    def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
         """The default SAML user mapping provider
 
         Args:
             parsed_config: Module configuration
+            module_api: module api proxy
         """
         self._mxid_source_attribute = parsed_config.mxid_source_attribute
         self._mxid_mapper = parsed_config.mxid_mapper
 
+        self._grandfathered_mxid_source_attribute = (
+            module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+        )
+
+    def get_remote_user_id(
+        self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
+    ):
+        """Extracts the remote user id from the SAML response"""
+        try:
+            return saml_response.ava["uid"][0]
+        except KeyError:
+            logger.warning("SAML2 response lacks a 'uid' attestation")
+            raise SynapseError(400, "'uid' not in SAML2 response")
+
     def saml_response_to_user_attributes(
-        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+        self,
+        saml_response: saml2.response.AuthnResponse,
+        failures: int,
+        client_redirect_url: str,
     ) -> dict:
         """Maps some text from a SAML response to attributes of a new user
 
@@ -285,6 +319,8 @@ class DefaultSamlMappingProvider(object):
             failures: How many times a call to this function with this
                 saml_response has resulted in a failure
 
+            client_redirect_url: where the client wants to redirect to
+
         Returns:
             dict: A dict containing new user attributes. Possible keys:
                 * mxid_localpart (str): Required. The localpart of the user's mxid