summary refs log tree commit diff
path: root/synapse/handlers/sso.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-02-01 13:15:51 +0000
committerGitHub <noreply@github.com>2021-02-01 13:15:51 +0000
commitf78d07bf005f7212bcc74256721677a3b255ea0e (patch)
treebd70a722503feece7aff57df2bc71b25272b95b0 /synapse/handlers/sso.py
parentAdd 'brand' field to MSC2858 response (#9242) (diff)
downloadsynapse-f78d07bf005f7212bcc74256721677a3b255ea0e.tar.xz
Split out a separate endpoint to complete SSO registration (#9262)
There are going to be a couple of paths to get to the final step of SSO reg, and I want the URL in the browser to consistent. So, let's move the final step onto a separate path, which we redirect to.
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r--synapse/handlers/sso.py81
1 files changed, 66 insertions, 15 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 3308b037d2..50c5ae142a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -21,12 +21,13 @@ import attr
 from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
+from twisted.web.iweb import IRequest
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
-from synapse.http.server import respond_with_html
+from synapse.http.server import respond_with_html, respond_with_redirect
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
@@ -141,6 +142,9 @@ class UsernameMappingSession:
     # expiry time for the session, in milliseconds
     expiry_time_ms = attr.ib(type=int)
 
+    # choices made by the user
+    chosen_localpart = attr.ib(type=Optional[str], default=None)
+
 
 # the HTTP cookie used to track the mapping session id
 USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@@ -647,6 +651,25 @@ class SsoHandler:
         )
         respond_with_html(request, 200, html)
 
+    def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+        """Look up the given username mapping session
+
+        If it is not found, raises a SynapseError with an http code of 400
+
+        Args:
+            session_id: session to look up
+        Returns:
+            active mapping session
+        Raises:
+            SynapseError if the session is not found/has expired
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if session:
+            return session
+        logger.info("Couldn't find session id %s", session_id)
+        raise SynapseError(400, "unknown session")
+
     async def check_username_availability(
         self, localpart: str, session_id: str,
     ) -> bool:
@@ -663,12 +686,7 @@ class SsoHandler:
 
         # make sure that there is a valid mapping session, to stop people dictionary-
         # scanning for accounts
-
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        self.get_mapping_session(session_id)
 
         logger.info(
             "[session %s] Checking for availability of username %s",
@@ -696,16 +714,33 @@ class SsoHandler:
             localpart: localpart requested by the user
             session_id: ID of the username mapping session, extracted from a cookie
         """
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        session = self.get_mapping_session(session_id)
+
+        # update the session with the user's choices
+        session.chosen_localpart = localpart
+
+        # we're done; now we can register the user
+        respond_with_redirect(request, b"/_synapse/client/sso_register")
+
+    async def register_sso_user(self, request: Request, session_id: str) -> None:
+        """Called once we have all the info we need to register a new user.
 
-        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+        Does so and serves an HTTP response
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        session = self.get_mapping_session(session_id)
+
+        logger.info(
+            "[session %s] Registering localpart %s",
+            session_id,
+            session.chosen_localpart,
+        )
 
         attributes = UserAttributes(
-            localpart=localpart,
+            localpart=session.chosen_localpart,
             display_name=session.display_name,
             emails=session.emails,
         )
@@ -720,7 +755,12 @@ class SsoHandler:
             request.getClientIP(),
         )
 
-        logger.info("[session %s] Registered userid %s", session_id, user_id)
+        logger.info(
+            "[session %s] Registered userid %s with attributes %s",
+            session_id,
+            user_id,
+            attributes,
+        )
 
         # delete the mapping session and the cookie
         del self._username_mapping_sessions[session_id]
@@ -751,3 +791,14 @@ class SsoHandler:
         for session_id in to_expire:
             logger.info("Expiring mapping session %s", session_id)
             del self._username_mapping_sessions[session_id]
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+    """Extract the session ID from the cookie
+
+    Raises a SynapseError if the cookie isn't found
+    """
+    session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+    if not session_id:
+        raise SynapseError(code=400, msg="missing session_id")
+    return session_id.decode("ascii", errors="replace")