diff options
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r-- | synapse/handlers/sso.py | 220 |
1 files changed, 191 insertions, 29 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index dcc85e9871..b450668f1c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -14,21 +14,31 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Set, +) from urllib.parse import urlencode import attr from typing_extensions import NoReturn, Protocol from twisted.web.http import Request +from twisted.web.iweb import IRequest from synapse.api.constants import LoginType -from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent -from synapse.http.server import respond_with_html +from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters +from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer from synapse.util.stringutils import random_string @@ -75,6 +85,16 @@ class SsoIdentityProvider(Protocol): def idp_name(self) -> str: """User-facing name for this provider""" + @property + def idp_icon(self) -> Optional[str]: + """Optional MXC URI for user-facing icon""" + return None + + @property + def idp_brand(self) -> Optional[str]: + """Optional branding identifier""" + return None + @abc.abstractmethod async def handle_redirect_request( self, @@ -104,7 +124,7 @@ class UserAttributes: # enter one. localpart = attr.ib(type=Optional[str]) display_name = attr.ib(type=Optional[str], default=None) - emails = attr.ib(type=List[str], default=attr.Factory(list)) + emails = attr.ib(type=Collection[str], default=attr.Factory(list)) @attr.s(slots=True) @@ -119,7 +139,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name = attr.ib(type=Optional[str]) - emails = attr.ib(type=List[str]) + emails = attr.ib(type=Collection[str]) # An optional dictionary of extra attributes to be provided to the client in the # login response. @@ -131,6 +151,12 @@ class UsernameMappingSession: # expiry time for the session, in milliseconds expiry_time_ms = attr.ib(type=int) + # choices made by the user + chosen_localpart = attr.ib(type=Optional[str], default=None) + use_display_name = attr.ib(type=bool, default=True) + emails_to_use = attr.ib(type=Collection[str], default=()) + terms_accepted_version = attr.ib(type=Optional[str], default=None) + # the HTTP cookie used to track the mapping session id USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" @@ -165,6 +191,8 @@ class SsoHandler: # map from idp_id to SsoIdentityProvider self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + self._consent_at_registration = hs.config.consent.user_consent_at_registration + def register_identity_provider(self, p: SsoIdentityProvider): p_id = p.idp_id assert p_id not in self._identity_providers @@ -230,7 +258,10 @@ class SsoHandler: respond_with_html(request, code, html) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes, + self, + request: SynapseRequest, + client_redirect_url: bytes, + idp_id: Optional[str], ) -> str: """Handle a request to /login/sso/redirect @@ -238,6 +269,7 @@ class SsoHandler: request: incoming HTTP request client_redirect_url: the URL that we should redirect the client to after login. + idp_id: optional identity provider chosen by the client Returns: the URI to redirect to @@ -247,10 +279,19 @@ class SsoHandler: 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED ) + # if the client chose an IdP, use that + idp = None # type: Optional[SsoIdentityProvider] + if idp_id: + idp = self._identity_providers.get(idp_id) + if not idp: + raise NotFoundError("Unknown identity provider") + # if we only have one auth provider, redirect to it directly - if len(self._identity_providers) == 1: - ap = next(iter(self._identity_providers.values())) - return await ap.handle_redirect_request(request, client_redirect_url) + elif len(self._identity_providers) == 1: + idp = next(iter(self._identity_providers.values())) + + if idp: + return await idp.handle_redirect_request(request, client_redirect_url) # otherwise, redirect to the IDP picker return "/_synapse/client/pick_idp?" + urlencode( @@ -364,6 +405,8 @@ class SsoHandler: to an additional page. (e.g. to prompt for more information) """ + new_user = False + # grab a lock while we try to find a mapping for this user. This seems... # optimistic, especially for implementations that end up redirecting to # interstitial pages. @@ -404,9 +447,14 @@ class SsoHandler: get_request_user_agent(request), request.getClientIP(), ) + new_user = True await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url, extra_login_attributes + user_id, + request, + client_redirect_url, + extra_login_attributes, + new_user=new_user, ) async def _call_attribute_mapper( @@ -496,7 +544,7 @@ class SsoHandler: logger.info("Recorded registration session id %s", session_id) # Set the cookie and redirect to the username picker - e = RedirectException(b"/_synapse/client/pick_username") + e = RedirectException(b"/_synapse/client/pick_username/account_details") e.cookies.append( b"%s=%s; path=/" % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) @@ -624,6 +672,25 @@ class SsoHandler: ) respond_with_html(request, 200, html) + def get_mapping_session(self, session_id: str) -> UsernameMappingSession: + """Look up the given username mapping session + + If it is not found, raises a SynapseError with an http code of 400 + + Args: + session_id: session to look up + Returns: + active mapping session + Raises: + SynapseError if the session is not found/has expired + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if session: + return session + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + async def check_username_availability( self, localpart: str, session_id: str, ) -> bool: @@ -640,12 +707,7 @@ class SsoHandler: # make sure that there is a valid mapping session, to stop people dictionary- # scanning for accounts - - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + self.get_mapping_session(session_id) logger.info( "[session %s] Checking for availability of username %s", @@ -662,7 +724,12 @@ class SsoHandler: return not user_infos async def handle_submit_username_request( - self, request: SynapseRequest, localpart: str, session_id: str + self, + request: SynapseRequest, + session_id: str, + localpart: str, + use_display_name: bool, + emails_to_use: Iterable[str], ) -> None: """Handle a request to the username-picker 'submit' endpoint @@ -672,21 +739,90 @@ class SsoHandler: request: HTTP request localpart: localpart requested by the user session_id: ID of the username mapping session, extracted from a cookie + use_display_name: whether the user wants to use the suggested display name + emails_to_use: emails that the user would like to use """ - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + session = self.get_mapping_session(session_id) + + # update the session with the user's choices + session.chosen_localpart = localpart + session.use_display_name = use_display_name + + emails_from_idp = set(session.emails) + filtered_emails = set() # type: Set[str] + + # we iterate through the list rather than just building a set conjunction, so + # that we can log attempts to use unknown addresses + for email in emails_to_use: + if email in emails_from_idp: + filtered_emails.add(email) + else: + logger.warning( + "[session %s] ignoring user request to use unknown email address %r", + session_id, + email, + ) + session.emails_to_use = filtered_emails - logger.info("[session %s] Registering localpart %s", session_id, localpart) + # we may now need to collect consent from the user, in which case, redirect + # to the consent-extraction-unit + if self._consent_at_registration: + redirect_url = b"/_synapse/client/new_user_consent" + + # otherwise, redirect to the completion page + else: + redirect_url = b"/_synapse/client/sso_register" + + respond_with_redirect(request, redirect_url) + + async def handle_terms_accepted( + self, request: Request, session_id: str, terms_version: str + ): + """Handle a request to the new-user 'consent' endpoint + + Will serve an HTTP response to the request. + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + terms_version: the version of the terms which the user viewed and consented + to + """ + logger.info( + "[session %s] User consented to terms version %s", + session_id, + terms_version, + ) + session = self.get_mapping_session(session_id) + session.terms_accepted_version = terms_version + + # we're done; now we can register the user + respond_with_redirect(request, b"/_synapse/client/sso_register") + + async def register_sso_user(self, request: Request, session_id: str) -> None: + """Called once we have all the info we need to register a new user. + + Does so and serves an HTTP response + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + """ + session = self.get_mapping_session(session_id) + + logger.info( + "[session %s] Registering localpart %s", + session_id, + session.chosen_localpart, + ) attributes = UserAttributes( - localpart=localpart, - display_name=session.display_name, - emails=session.emails, + localpart=session.chosen_localpart, emails=session.emails_to_use, ) + if session.use_display_name: + attributes.display_name = session.display_name + # the following will raise a 400 error if the username has been taken in the # meantime. user_id = await self._register_mapped_user( @@ -697,7 +833,12 @@ class SsoHandler: request.getClientIP(), ) - logger.info("[session %s] Registered userid %s", session_id, user_id) + logger.info( + "[session %s] Registered userid %s with attributes %s", + session_id, + user_id, + attributes, + ) # delete the mapping session and the cookie del self._username_mapping_sessions[session_id] @@ -710,11 +851,21 @@ class SsoHandler: path=b"/", ) + auth_result = {} + if session.terms_accepted_version: + # TODO: make this less awful. + auth_result[LoginType.TERMS] = True + + await self._registration_handler.post_registration_actions( + user_id, auth_result, access_token=None + ) + await self._auth_handler.complete_sso_login( user_id, request, session.client_redirect_url, session.extra_login_attributes, + new_user=True, ) def _expire_old_sessions(self): @@ -728,3 +879,14 @@ class SsoHandler: for session_id in to_expire: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + + +def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: + """Extract the session ID from the cookie + + Raises a SynapseError if the cookie isn't found + """ + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + return session_id.decode("ascii", errors="replace") |