diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 3308b037d2..50c5ae142a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -21,12 +21,13 @@ import attr
from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
+from twisted.web.iweb import IRequest
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
-from synapse.http.server import respond_with_html
+from synapse.http.server import respond_with_html, respond_with_redirect
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer
@@ -141,6 +142,9 @@ class UsernameMappingSession:
# expiry time for the session, in milliseconds
expiry_time_ms = attr.ib(type=int)
+ # choices made by the user
+ chosen_localpart = attr.ib(type=Optional[str], default=None)
+
# the HTTP cookie used to track the mapping session id
USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@@ -647,6 +651,25 @@ class SsoHandler:
)
respond_with_html(request, 200, html)
+ def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+ """Look up the given username mapping session
+
+ If it is not found, raises a SynapseError with an http code of 400
+
+ Args:
+ session_id: session to look up
+ Returns:
+ active mapping session
+ Raises:
+ SynapseError if the session is not found/has expired
+ """
+ self._expire_old_sessions()
+ session = self._username_mapping_sessions.get(session_id)
+ if session:
+ return session
+ logger.info("Couldn't find session id %s", session_id)
+ raise SynapseError(400, "unknown session")
+
async def check_username_availability(
self, localpart: str, session_id: str,
) -> bool:
@@ -663,12 +686,7 @@ class SsoHandler:
# make sure that there is a valid mapping session, to stop people dictionary-
# scanning for accounts
-
- self._expire_old_sessions()
- session = self._username_mapping_sessions.get(session_id)
- if not session:
- logger.info("Couldn't find session id %s", session_id)
- raise SynapseError(400, "unknown session")
+ self.get_mapping_session(session_id)
logger.info(
"[session %s] Checking for availability of username %s",
@@ -696,16 +714,33 @@ class SsoHandler:
localpart: localpart requested by the user
session_id: ID of the username mapping session, extracted from a cookie
"""
- self._expire_old_sessions()
- session = self._username_mapping_sessions.get(session_id)
- if not session:
- logger.info("Couldn't find session id %s", session_id)
- raise SynapseError(400, "unknown session")
+ session = self.get_mapping_session(session_id)
+
+ # update the session with the user's choices
+ session.chosen_localpart = localpart
+
+ # we're done; now we can register the user
+ respond_with_redirect(request, b"/_synapse/client/sso_register")
+
+ async def register_sso_user(self, request: Request, session_id: str) -> None:
+ """Called once we have all the info we need to register a new user.
- logger.info("[session %s] Registering localpart %s", session_id, localpart)
+ Does so and serves an HTTP response
+
+ Args:
+ request: HTTP request
+ session_id: ID of the username mapping session, extracted from a cookie
+ """
+ session = self.get_mapping_session(session_id)
+
+ logger.info(
+ "[session %s] Registering localpart %s",
+ session_id,
+ session.chosen_localpart,
+ )
attributes = UserAttributes(
- localpart=localpart,
+ localpart=session.chosen_localpart,
display_name=session.display_name,
emails=session.emails,
)
@@ -720,7 +755,12 @@ class SsoHandler:
request.getClientIP(),
)
- logger.info("[session %s] Registered userid %s", session_id, user_id)
+ logger.info(
+ "[session %s] Registered userid %s with attributes %s",
+ session_id,
+ user_id,
+ attributes,
+ )
# delete the mapping session and the cookie
del self._username_mapping_sessions[session_id]
@@ -751,3 +791,14 @@ class SsoHandler:
for session_id in to_expire:
logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id]
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+ """Extract the session ID from the cookie
+
+ Raises a SynapseError if the cookie isn't found
+ """
+ session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+ if not session_id:
+ raise SynapseError(code=400, msg="missing session_id")
+ return session_id.decode("ascii", errors="replace")
|