summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py120
1 files changed, 68 insertions, 52 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f4434673dc..0e98db22b3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -49,8 +49,13 @@ from synapse.api.errors import (
     UserDeactivatedError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.ui_auth import (
+    INTERACTIVE_AUTH_CHECKERS,
+    UIAuthSessionDataConstants,
+)
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http import get_request_user_agent
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
@@ -62,8 +67,6 @@ from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.threepids import canonicalise_email
 
-from ._base import BaseHandler
-
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
@@ -260,10 +263,6 @@ class AuthHandler(BaseHandler):
         # authenticating for an operation to occur on their account.
         self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
 
-        # The following template is shown after a successful user interactive
-        # authentication session. It tells the user they can close the window.
-        self._sso_auth_success_template = hs.config.sso_auth_success_template
-
         # The following template is shown during the SSO authentication process if
         # the account is deactivated.
         self._sso_account_deactivated_template = (
@@ -284,7 +283,6 @@ class AuthHandler(BaseHandler):
         requester: Requester,
         request: SynapseRequest,
         request_body: Dict[str, Any],
-        clientip: str,
         description: str,
     ) -> Tuple[dict, Optional[str]]:
         """
@@ -301,8 +299,6 @@ class AuthHandler(BaseHandler):
 
             request_body: The body of the request sent by the client
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
@@ -338,10 +334,10 @@ class AuthHandler(BaseHandler):
                 request_body.pop("auth", None)
                 return request_body, None
 
-        user_id = requester.user.to_string()
+        requester_user_id = requester.user.to_string()
 
         # Check if we should be ratelimited due to too many previous failed attempts
-        self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+        self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
 
         # build a list of supported flows
         supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -349,13 +345,16 @@ class AuthHandler(BaseHandler):
         )
         flows = [[login_type] for login_type in supported_ui_auth_types]
 
+        def get_new_session_data() -> JsonDict:
+            return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
+
         try:
             result, params, session_id = await self.check_ui_auth(
-                flows, request, request_body, clientip, description
+                flows, request, request_body, description, get_new_session_data,
             )
         except LoginError:
             # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
-            self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+            self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
             raise
 
         # find the completed login type
@@ -363,14 +362,14 @@ class AuthHandler(BaseHandler):
             if login_type not in result:
                 continue
 
-            user_id = result[login_type]
+            validated_user_id = result[login_type]
             break
         else:
             # this can't happen
             raise Exception("check_auth returned True but no successful login type")
 
         # check that the UI auth matched the access token
-        if user_id != requester.user.to_string():
+        if validated_user_id != requester_user_id:
             raise AuthError(403, "Invalid auth")
 
         # Note that the access token has been validated.
@@ -402,13 +401,9 @@ class AuthHandler(BaseHandler):
 
         # if sso is enabled, allow the user to log in via SSO iff they have a mapping
         # from sso to mxid.
-        if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
-            if await self.store.get_external_ids_by_user(user.to_string()):
-                ui_auth_types.add(LoginType.SSO)
-
-        # Our CAS impl does not (yet) correctly register users in user_external_ids,
-        # so always offer that if it's available.
-        if self.hs.config.cas.cas_enabled:
+        if await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user.to_string()
+        ):
             ui_auth_types.add(LoginType.SSO)
 
         return ui_auth_types
@@ -426,8 +421,8 @@ class AuthHandler(BaseHandler):
         flows: List[List[str]],
         request: SynapseRequest,
         clientdict: Dict[str, Any],
-        clientip: str,
         description: str,
+        get_new_session_data: Optional[Callable[[], JsonDict]] = None,
     ) -> Tuple[dict, dict, str]:
         """
         Takes a dictionary sent by the client in the login / registration
@@ -448,11 +443,16 @@ class AuthHandler(BaseHandler):
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
 
-            clientip: The IP address of the client.
-
             description: A human readable string to be displayed to the user that
                          describes the operation happening on their account.
 
+            get_new_session_data:
+                an optional callback which will be called when starting a new session.
+                it should return data to be stored as part of the session.
+
+                The keys of the returned data should be entries in
+                UIAuthSessionDataConstants.
+
         Returns:
             A tuple of (creds, params, session_id).
 
@@ -480,10 +480,15 @@ class AuthHandler(BaseHandler):
 
         # If there's no session ID, create a new session.
         if not sid:
+            new_session_data = get_new_session_data() if get_new_session_data else {}
+
             session = await self.store.create_ui_auth_session(
                 clientdict, uri, method, description
             )
 
+            for k, v in new_session_data.items():
+                await self.set_session_data(session.session_id, k, v)
+
         else:
             try:
                 session = await self.store.get_ui_auth_session(sid)
@@ -539,7 +544,8 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
-        user_agent = request.get_user_agent("")
+        user_agent = get_request_user_agent(request)
+        clientip = request.getClientIP()
 
         await self.store.add_user_agent_ip_to_ui_auth_session(
             session.session_id, user_agent, clientip
@@ -644,7 +650,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key to store the data under. An entry from
+                UIAuthSessionDataConstants.
             value: The data to store
         """
         try:
@@ -660,7 +667,8 @@ class AuthHandler(BaseHandler):
 
         Args:
             session_id: The ID of this session as returned from check_auth
-            key: The key to store the data under
+            key: The key the data was stored under. An entry from
+                UIAuthSessionDataConstants.
             default: Value to return if the key has not been set
         """
         try:
@@ -1334,12 +1342,12 @@ class AuthHandler(BaseHandler):
         else:
             return False
 
-    async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+    async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
         """
         Get the HTML for the SSO redirect confirmation page.
 
         Args:
-            redirect_url: The URL to redirect to the SSO provider.
+            request: The incoming HTTP request
             session_id: The user interactive authentication session ID.
 
         Returns:
@@ -1349,30 +1357,38 @@ class AuthHandler(BaseHandler):
             session = await self.store.get_ui_auth_session(session_id)
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
-        return self._sso_auth_confirm_template.render(
-            description=session.description, redirect_url=redirect_url,
+
+        user_id_to_verify = await self.get_session_data(
+            session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+        )  # type: str
+
+        idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+            user_id_to_verify
         )
 
-    async def complete_sso_ui_auth(
-        self, registered_user_id: str, session_id: str, request: Request,
-    ):
-        """Having figured out a mxid for this user, complete the HTTP request
+        if not idps:
+            # we checked that the user had some remote identities before offering an SSO
+            # flow, so either it's been deleted or the client has requested SSO despite
+            # it not being offered.
+            raise SynapseError(400, "User has no SSO identities")
 
-        Args:
-            registered_user_id: The registered user ID to complete SSO login for.
-            session_id: The ID of the user-interactive auth session.
-            request: The request to complete.
-        """
-        # Mark the stage of the authentication as successful.
-        # Save the user who authenticated with SSO, this will be used to ensure
-        # that the account be modified is also the person who logged in.
-        await self.store.mark_ui_auth_stage_complete(
-            session_id, LoginType.SSO, registered_user_id
+        # for now, just pick one
+        idp_id, sso_auth_provider = next(iter(idps.items()))
+        if len(idps) > 0:
+            logger.warning(
+                "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+                "picking %r",
+                user_id_to_verify,
+                idp_id,
+            )
+
+        redirect_url = await sso_auth_provider.handle_redirect_request(
+            request, None, session_id
         )
 
-        # Render the HTML and return.
-        html = self._sso_auth_success_template
-        respond_with_html(request, 200, html)
+        return self._sso_auth_confirm_template.render(
+            description=session.description, redirect_url=redirect_url,
+        )
 
     async def complete_sso_login(
         self,
@@ -1488,8 +1504,8 @@ class AuthHandler(BaseHandler):
     @staticmethod
     def add_query_param_to_url(url: str, param_name: str, param: Any):
         url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({param_name: param})
+        query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+        query.append((param_name, param))
         url_parts[4] = urllib.parse.urlencode(query)
         return urllib.parse.urlunparse(url_parts)