summary refs log tree commit diff
path: root/synapse/handlers/sso.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r--synapse/handlers/sso.py86
1 files changed, 78 insertions, 8 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 2da1ea2223..d493327a10 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -22,7 +22,10 @@ from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
 
+from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
+from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@@ -72,6 +75,11 @@ class SsoIdentityProvider(Protocol):
     def idp_name(self) -> str:
         """User-facing name for this provider"""
 
+    @property
+    def idp_icon(self) -> Optional[str]:
+        """Optional MXC URI for user-facing icon"""
+        return None
+
     @abc.abstractmethod
     async def handle_redirect_request(
         self,
@@ -145,8 +153,13 @@ class SsoHandler:
         self._store = hs.get_datastore()
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
-        self._error_template = hs.config.sso_error_template
         self._auth_handler = hs.get_auth_handler()
+        self._error_template = hs.config.sso_error_template
+        self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+        # The following template is shown after a successful user interactive
+        # authentication session. It tells the user they can close the window.
+        self._sso_auth_success_template = hs.config.sso_auth_success_template
 
         # a lock on the mappings
         self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
@@ -166,6 +179,37 @@ class SsoHandler:
         """Get the configured identity providers"""
         return self._identity_providers
 
+    async def get_identity_providers_for_user(
+        self, user_id: str
+    ) -> Mapping[str, SsoIdentityProvider]:
+        """Get the SsoIdentityProviders which a user has used
+
+        Given a user id, get the identity providers that that user has used to log in
+        with in the past (and thus could use to re-identify themselves for UI Auth).
+
+        Args:
+            user_id: MXID of user to look up
+
+        Raises:
+            a map of idp_id to SsoIdentityProvider
+        """
+        external_ids = await self._store.get_external_ids_by_user(user_id)
+
+        valid_idps = {}
+        for idp_id, _ in external_ids:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                logger.warning(
+                    "User %r has an SSO mapping for IdP %r, but this is no longer "
+                    "configured.",
+                    user_id,
+                    idp_id,
+                )
+            else:
+                valid_idps[idp_id] = idp
+
+        return valid_idps
+
     def render_error(
         self,
         request: Request,
@@ -362,7 +406,7 @@ class SsoHandler:
                     attributes,
                     auth_provider_id,
                     remote_user_id,
-                    request.get_user_agent(""),
+                    get_request_user_agent(request),
                     request.getClientIP(),
                 )
 
@@ -545,19 +589,45 @@ class SsoHandler:
             auth_provider_id, remote_user_id,
         )
 
+        user_id_to_verify = await self._auth_handler.get_session_data(
+            ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
         if not user_id:
             logger.warning(
                 "Remote user %s/%s has not previously logged in here: UIA will fail",
                 auth_provider_id,
                 remote_user_id,
             )
-            # Let the UIA flow handle this the same as if they presented creds for a
-            # different user.
-            user_id = ""
+        elif user_id != user_id_to_verify:
+            logger.warning(
+                "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+                user_id,
+            )
+        else:
+            # success!
+            # Mark the stage of the authentication as successful.
+            await self._store.mark_ui_auth_stage_complete(
+                ui_auth_session_id, LoginType.SSO, user_id
+            )
+
+            # Render the HTML confirmation page and return.
+            html = self._sso_auth_success_template
+            respond_with_html(request, 200, html)
+            return
+
+        # the user_id didn't match: mark the stage of the authentication as unsuccessful
+        await self._store.mark_ui_auth_stage_complete(
+            ui_auth_session_id, LoginType.SSO, ""
+        )
 
-        await self._auth_handler.complete_sso_ui_auth(
-            user_id, ui_auth_session_id, request
+        # render an error page.
+        html = self._bad_user_template.render(
+            server_name=self._server_name, user_id_to_verify=user_id_to_verify,
         )
+        respond_with_html(request, 200, html)
 
     async def check_username_availability(
         self, localpart: str, session_id: str,
@@ -628,7 +698,7 @@ class SsoHandler:
             attributes,
             session.auth_provider_id,
             session.remote_user_id,
-            request.get_user_agent(""),
+            get_request_user_agent(request),
             request.getClientIP(),
         )