summary refs log tree commit diff
path: root/synapse/handlers/saml_handler.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/saml_handler.py161
1 files changed, 51 insertions, 110 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 285c481a96..76d4169fe2 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
-    import synapse.server
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
-class MappingException(Exception):
-    """Used to catch errors when mapping the SAML2 response to a user."""
-
-
 @attr.s(slots=True)
 class Saml2SessionData:
     """Data we track about SAML2 sessions"""
@@ -57,17 +54,14 @@ class Saml2SessionData:
     ui_auth_session_id = attr.ib(type=Optional[str], default=None)
 
 
-class SamlHandler:
-    def __init__(self, hs: "synapse.server.HomeServer"):
-        self.hs = hs
+class SamlHandler(BaseHandler):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
-        self._auth = hs.get_auth()
+        self._saml_idp_entityid = hs.config.saml2_idp_entityid
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
-        self._clock = hs.get_clock()
-        self._datastore = hs.get_datastore()
-        self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
@@ -88,26 +82,9 @@ class SamlHandler:
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         # a lock on the mappings
-        self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
-
-    def _render_error(
-        self, request, error: str, error_description: Optional[str] = None
-    ) -> None:
-        """Render the error template and respond to the request with it.
-
-        This is used to show errors to the user. The template of this page can
-        be found under `synapse/res/templates/sso_error.html`.
+        self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
 
-        Args:
-            request: The incoming request from the browser.
-                We'll respond with an HTML page describing the error.
-            error: A technical identifier for this error.
-            error_description: A human-readable description of the error.
-        """
-        html = self._error_template.render(
-            error=error, error_description=error_description
-        )
-        respond_with_html(request, 400, html)
+        self._sso_handler = hs.get_sso_handler()
 
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@@ -124,13 +101,13 @@ class SamlHandler:
             URL to redirect to
         """
         reqid, info = self._saml_client.prepare_for_authenticate(
-            relay_state=client_redirect_url
+            entityid=self._saml_idp_entityid, relay_state=client_redirect_url
         )
 
         # Since SAML sessions timeout it is useful to log when they were created.
         logger.info("Initiating a new SAML session: %s" % (reqid,))
 
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         self._outstanding_requests_dict[reqid] = Saml2SessionData(
             creation_time=now, ui_auth_session_id=ui_auth_session_id,
         )
@@ -171,12 +148,12 @@ class SamlHandler:
             # in the (user-visible) exception message, so let's log the exception here
             # so we can track down the session IDs later.
             logger.warning(str(e))
-            self._render_error(
+            self._sso_handler.render_error(
                 request, "unsolicited_response", "Unexpected SAML2 login."
             )
             return
         except Exception as e:
-            self._render_error(
+            self._sso_handler.render_error(
                 request,
                 "invalid_response",
                 "Unable to parse SAML2 response: %s." % (e,),
@@ -184,7 +161,7 @@ class SamlHandler:
             return
 
         if saml2_auth.not_signed:
-            self._render_error(
+            self._sso_handler.render_error(
                 request, "unsigned_respond", "SAML2 response was not signed."
             )
             return
@@ -210,15 +187,13 @@ class SamlHandler:
         # attributes.
         for requirement in self._saml2_attribute_requirements:
             if not _check_attribute_requirement(saml2_auth.ava, requirement):
-                self._render_error(
+                self._sso_handler.render_error(
                     request, "unauthorised", "You are not authorised to log in here."
                 )
                 return
 
         # Pull out the user-agent and IP from the request.
-        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
-            0
-        ].decode("ascii", "surrogateescape")
+        user_agent = request.get_user_agent("")
         ip_address = self.hs.get_ip_from_request(request)
 
         # Call the mapper to register/login the user
@@ -228,7 +203,7 @@ class SamlHandler:
             )
         except MappingException as e:
             logger.exception("Could not map user")
-            self._render_error(request, "mapping_error", str(e))
+            self._sso_handler.render_error(request, "mapping_error", str(e))
             return
 
         # Complete the interactive auth session or the login.
@@ -274,20 +249,26 @@ class SamlHandler:
                 "Failed to extract remote user id from SAML response"
             )
 
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
-            # first of all, check if we already have a mapping for this user
-            logger.info(
-                "Looking for existing mapping for user %s:%s",
-                self._auth_provider_id,
-                remote_user_id,
+        async def saml_response_to_remapped_user_attributes(
+            failures: int,
+        ) -> UserAttributes:
+            """
+            Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+            This is backwards compatibility for abstraction for the SSO handler.
+            """
+            # Call the mapping provider.
+            result = self._user_mapping_provider.saml_response_to_user_attributes(
+                saml2_auth, failures, client_redirect_url
             )
-            registered_user_id = await self._datastore.get_user_by_external_id(
-                self._auth_provider_id, remote_user_id
+            # Remap some of the results.
+            return UserAttributes(
+                localpart=result.get("mxid_localpart"),
+                display_name=result.get("displayname"),
+                emails=result.get("emails", []),
             )
-            if registered_user_id is not None:
-                logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id
 
+        async def grandfather_existing_users() -> Optional[str]:
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
             if (
@@ -296,75 +277,35 @@ class SamlHandler:
             ):
                 attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
                 user_id = UserID(
-                    map_username_to_mxid_localpart(attrval), self._hostname
+                    map_username_to_mxid_localpart(attrval), self.server_name
                 ).to_string()
-                logger.info(
+
+                logger.debug(
                     "Looking for existing account based on mapped %s %s",
                     self._grandfathered_mxid_source_attribute,
                     user_id,
                 )
 
-                users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+                users = await self.store.get_users_by_id_case_insensitive(user_id)
                 if users:
                     registered_user_id = list(users.keys())[0]
                     logger.info("Grandfathering mapping to %s", registered_user_id)
-                    await self._datastore.record_user_external_id(
-                        self._auth_provider_id, remote_user_id, registered_user_id
-                    )
                     return registered_user_id
 
-            # Map saml response to user attributes using the configured mapping provider
-            for i in range(1000):
-                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
-                    saml2_auth, i, client_redirect_url=client_redirect_url,
-                )
-
-                logger.debug(
-                    "Retrieved SAML attributes from user mapping provider: %s "
-                    "(attempt %d)",
-                    attribute_dict,
-                    i,
-                )
-
-                localpart = attribute_dict.get("mxid_localpart")
-                if not localpart:
-                    raise MappingException(
-                        "Error parsing SAML2 response: SAML mapping provider plugin "
-                        "did not return a mxid_localpart value"
-                    )
-
-                displayname = attribute_dict.get("displayname")
-                emails = attribute_dict.get("emails", [])
-
-                # Check if this mxid already exists
-                if not await self._datastore.get_users_by_id_case_insensitive(
-                    UserID(localpart, self._hostname).to_string()
-                ):
-                    # This mxid is free
-                    break
-            else:
-                # Unable to generate a username in 1000 iterations
-                # Break and return error to the user
-                raise MappingException(
-                    "Unable to generate a Matrix ID from the SAML response"
-                )
+            return None
 
-            logger.info("Mapped SAML user to local part %s", localpart)
-
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=displayname,
-                bind_emails=emails,
-                user_agent_ips=(user_agent, ip_address),
-            )
-
-            await self._datastore.record_user_external_id(
-                self._auth_provider_id, remote_user_id, registered_user_id
+        with (await self._mapping_lock.queue(self._auth_provider_id)):
+            return await self._sso_handler.get_mxid_from_sso(
+                self._auth_provider_id,
+                remote_user_id,
+                user_agent,
+                ip_address,
+                saml_response_to_remapped_user_attributes,
+                grandfather_existing_users,
             )
-            return registered_user_id
 
     def expire_sessions(self):
-        expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+        expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
         for reqid, data in self._outstanding_requests_dict.items():
             if data.creation_time < expire_before:
@@ -476,11 +417,11 @@ class DefaultSamlMappingProvider:
             )
 
         # Use the configured mapper for this mxid_source
-        base_mxid_localpart = self._mxid_mapper(mxid_source)
+        localpart = self._mxid_mapper(mxid_source)
 
         # Append suffix integer if last call to this function failed to produce
-        # a usable mxid
-        localpart = base_mxid_localpart + (str(failures) if failures else "")
+        # a usable mxid.
+        localpart += str(failures) if failures else ""
 
         # Retrieve the display name from the saml response
         # If displayname is None, the mxid_localpart will be used instead