summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8881.misc1
-rw-r--r--mypy.ini1
-rw-r--r--synapse/handlers/_base.py4
-rw-r--r--synapse/handlers/auth.py11
-rw-r--r--synapse/handlers/oidc_handler.py44
-rw-r--r--synapse/handlers/saml_handler.py64
-rw-r--r--synapse/handlers/sso.py59
7 files changed, 144 insertions, 40 deletions
diff --git a/changelog.d/8881.misc b/changelog.d/8881.misc
new file mode 100644
index 0000000000..07d3f30fb2
--- /dev/null
+++ b/changelog.d/8881.misc
@@ -0,0 +1 @@
+Simplify logic for handling user-interactive-auth via single-sign-on servers.
diff --git a/mypy.ini b/mypy.ini
index 59144be469..12408b8d95 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -43,6 +43,7 @@ files =
   synapse/handlers/room_member.py,
   synapse/handlers/room_member_worker.py,
   synapse/handlers/saml_handler.py,
+  synapse/handlers/sso.py,
   synapse/handlers/sync.py,
   synapse/handlers/ui_auth,
   synapse/http/client.py,
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb81c0e81d..d29b066a56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
 class BaseHandler:
     """
     Common base class for the event handlers.
+
+    Deprecated: new code should not use this. Instead, Handler classes should define the
+    fields they actually need. The utility methods should either be factored out to
+    standalone helper functions, or to different Handler classes.
     """
 
     def __init__(self, hs: "HomeServer"):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2e72298e05..afae6d3272 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -36,6 +36,8 @@ import attr
 import bcrypt
 import pymacaroons
 
+from twisted.web.http import Request
+
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     AuthError,
@@ -1331,15 +1333,14 @@ class AuthHandler(BaseHandler):
         )
 
     async def complete_sso_ui_auth(
-        self, registered_user_id: str, session_id: str, request: SynapseRequest,
+        self, registered_user_id: str, session_id: str, request: Request,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
         Args:
             registered_user_id: The registered user ID to complete SSO login for.
+            session_id: The ID of the user-interactive auth session.
             request: The request to complete.
-            client_redirect_url: The URL to which to redirect the user at the end of the
-                process.
         """
         # Mark the stage of the authentication as successful.
         # Save the user who authenticated with SSO, this will be used to ensure
@@ -1355,7 +1356,7 @@ class AuthHandler(BaseHandler):
     async def complete_sso_login(
         self,
         registered_user_id: str,
-        request: SynapseRequest,
+        request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
     ):
@@ -1383,7 +1384,7 @@ class AuthHandler(BaseHandler):
     def _complete_sso_login(
         self,
         registered_user_id: str,
-        request: SynapseRequest,
+        request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
     ):
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..f626117f76 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
                 self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
+        # first check if we're doing a UIA
+        if ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_userinfo(userinfo)
+            except Exception as e:
+                logger.exception("Could not extract remote user id")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+            )
+
+        # otherwise, it's a login
+
         # Pull out the user-agent and IP from the request.
         user_agent = request.get_user_agent("")
         ip_address = self.hs.get_ip_from_request(request)
@@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
             extra_attributes = await get_extra_attributes(userinfo, token)
 
         # and finally complete the login
-        if ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, ui_auth_session_id, request
-            )
-        else:
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url, extra_attributes
-            )
+        await self._auth_handler.complete_sso_login(
+            user_id, request, client_redirect_url, extra_attributes
+        )
 
     def _generate_oidc_session_token(
         self,
@@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
             The mxid of the user
         """
         try:
-            remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+            remote_user_id = self._remote_id_from_userinfo(userinfo)
         except Exception as e:
             raise MappingException(
                 "Failed to extract subject from OIDC response: %s" % (e,)
             )
-        # Some OIDC providers use integer IDs, but Synapse expects external IDs
-        # to be strings.
-        remote_user_id = str(remote_user_id)
 
         # Older mapping providers don't accept the `failures` argument, so we
         # try and detect support.
@@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
             grandfather_existing_users,
         )
 
+    def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+        """Extract the unique remote id from an OIDC UserInfo block
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+        Returns:
+            remote user id
+        """
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+        # Some OIDC providers use integer IDs, but Synapse expects external IDs
+        # to be strings.
+        return str(remote_user_id)
+
 
 UserAttributeDict = TypedDict(
     "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..5846f08609 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -183,6 +183,24 @@ class SamlHandler(BaseHandler):
             saml2_auth.in_response_to, None
         )
 
+        # first check if we're doing a UIA
+        if current_session and current_session.ui_auth_session_id:
+            try:
+                remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+            except MappingException as e:
+                logger.exception("Failed to extract remote user id from SAML response")
+                self._sso_handler.render_error(request, "mapping_error", str(e))
+                return
+
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self._auth_provider_id,
+                remote_user_id,
+                current_session.ui_auth_session_id,
+                request,
+            )
+
+        # otherwise, we're handling a login request.
+
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
         for requirement in self._saml2_attribute_requirements:
@@ -206,14 +224,7 @@ class SamlHandler(BaseHandler):
             self._sso_handler.render_error(request, "mapping_error", str(e))
             return
 
-        # Complete the interactive auth session or the login.
-        if current_session and current_session.ui_auth_session_id:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, current_session.ui_auth_session_id, request
-            )
-
-        else:
-            await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+        await self._auth_handler.complete_sso_login(user_id, request, relay_state)
 
     async def _map_saml_response_to_user(
         self,
@@ -239,16 +250,10 @@ class SamlHandler(BaseHandler):
             RedirectException: some mapping providers may raise this if they need
                 to redirect to an interstitial page.
         """
-
-        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+        remote_user_id = self._remote_id_from_saml_response(
             saml2_auth, client_redirect_url
         )
 
-        if not remote_user_id:
-            raise MappingException(
-                "Failed to extract remote user id from SAML response"
-            )
-
         async def saml_response_to_remapped_user_attributes(
             failures: int,
         ) -> UserAttributes:
@@ -304,6 +309,35 @@ class SamlHandler(BaseHandler):
                 grandfather_existing_users,
             )
 
+    def _remote_id_from_saml_response(
+        self,
+        saml2_auth: saml2.response.AuthnResponse,
+        client_redirect_url: Optional[str],
+    ) -> str:
+        """Extract the unique remote id from a SAML2 AuthnResponse
+
+        Args:
+            saml2_auth: The parsed SAML2 response.
+            client_redirect_url: The redirect URL passed in by the client.
+        Returns:
+            remote user id
+
+        Raises:
+            MappingException if there was an error extracting the user id
+        """
+        # It's not obvious why we need to pass in the redirect URI to the mapping
+        # provider, but we do :/
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+            saml2_auth, client_redirect_url
+        )
+
+        if not remote_user_id:
+            raise MappingException(
+                "Failed to extract remote user id from SAML response"
+            )
+
+        return remote_user_id
+
     def expire_sessions(self):
         expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..e24767b921 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -17,8 +17,9 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
 
 import attr
 
+from twisted.web.http import Request
+
 from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
 from synapse.http.server import respond_with_html
 from synapse.types import UserID, contains_invalid_mxid_characters
 
@@ -42,14 +43,16 @@ class UserAttributes:
     emails = attr.ib(type=List[str], default=attr.Factory(list))
 
 
-class SsoHandler(BaseHandler):
+class SsoHandler:
     # The number of attempts to ask the mapping provider for when generating an MXID.
     _MAP_USERNAME_RETRIES = 1000
 
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        self._store = hs.get_datastore()
+        self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
         self._error_template = hs.config.sso_error_template
+        self._auth_handler = hs.get_auth_handler()
 
     def render_error(
         self, request, error: str, error_description: Optional[str] = None
@@ -95,7 +98,7 @@ class SsoHandler(BaseHandler):
         )
 
         # Check if we already have a mapping for this user.
-        previously_registered_user_id = await self.store.get_user_by_external_id(
+        previously_registered_user_id = await self._store.get_user_by_external_id(
             auth_provider_id, remote_user_id,
         )
 
@@ -181,7 +184,7 @@ class SsoHandler(BaseHandler):
             previously_registered_user_id = await grandfather_existing_users()
             if previously_registered_user_id:
                 # Future logins should also match this user ID.
-                await self.store.record_user_external_id(
+                await self._store.record_user_external_id(
                     auth_provider_id, remote_user_id, previously_registered_user_id
                 )
                 return previously_registered_user_id
@@ -214,8 +217,8 @@ class SsoHandler(BaseHandler):
                 )
 
             # Check if this mxid already exists
-            user_id = UserID(attributes.localpart, self.server_name).to_string()
-            if not await self.store.get_users_by_id_case_insensitive(user_id):
+            user_id = UserID(attributes.localpart, self._server_name).to_string()
+            if not await self._store.get_users_by_id_case_insensitive(user_id):
                 # This mxid is free
                 break
         else:
@@ -238,7 +241,47 @@ class SsoHandler(BaseHandler):
             user_agent_ips=[(user_agent, ip_address)],
         )
 
-        await self.store.record_user_external_id(
+        await self._store.record_user_external_id(
             auth_provider_id, remote_user_id, registered_user_id
         )
         return registered_user_id
+
+    async def complete_sso_ui_auth_request(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        ui_auth_session_id: str,
+        request: Request,
+    ) -> None:
+        """
+        Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+        Note that this requires that the user is mapped in the "user_external_ids"
+        table. This will be the case if they have ever logged in via SAML or OIDC in
+        recentish synapse versions, but may not be for older users.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            ui_auth_session_id: The ID of the user-interactive auth session.
+            request: The request to complete.
+        """
+
+        user_id = await self.get_sso_user_by_remote_user_id(
+            auth_provider_id, remote_user_id,
+        )
+
+        if not user_id:
+            logger.warning(
+                "Remote user %s/%s has not previously logged in here: UIA will fail",
+                auth_provider_id,
+                remote_user_id,
+            )
+            # Let the UIA flow handle this the same as if they presented creds for a
+            # different user.
+            user_id = ""
+
+        await self._auth_handler.complete_sso_ui_auth(
+            user_id, ui_auth_session_id, request
+        )