diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2f5b2b61aa..4f881a439a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -50,7 +50,10 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers._base import BaseHandler
-from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers.ui_auth import (
+ INTERACTIVE_AUTH_CHECKERS,
+ UIAuthSessionDataConstants,
+)
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http import get_request_user_agent
from synapse.http.server import finish_request, respond_with_html
@@ -335,10 +338,10 @@ class AuthHandler(BaseHandler):
request_body.pop("auth", None)
return request_body, None
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
- self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
+ self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
# build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -346,13 +349,16 @@ class AuthHandler(BaseHandler):
)
flows = [[login_type] for login_type in supported_ui_auth_types]
+ def get_new_session_data() -> JsonDict:
+ return {UIAuthSessionDataConstants.REQUEST_USER_ID: requester_user_id}
+
try:
result, params, session_id = await self.check_ui_auth(
- flows, request, request_body, description
+ flows, request, request_body, description, get_new_session_data,
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
- self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
+ self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
raise
# find the completed login type
@@ -360,14 +366,14 @@ class AuthHandler(BaseHandler):
if login_type not in result:
continue
- user_id = result[login_type]
+ validated_user_id = result[login_type]
break
else:
# this can't happen
raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
- if user_id != requester.user.to_string():
+ if validated_user_id != requester_user_id:
raise AuthError(403, "Invalid auth")
# Note that the access token has been validated.
@@ -399,13 +405,9 @@ class AuthHandler(BaseHandler):
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
# from sso to mxid.
- if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
- if await self.store.get_external_ids_by_user(user.to_string()):
- ui_auth_types.add(LoginType.SSO)
-
- # Our CAS impl does not (yet) correctly register users in user_external_ids,
- # so always offer that if it's available.
- if self.hs.config.cas.cas_enabled:
+ if await self.hs.get_sso_handler().get_identity_providers_for_user(
+ user.to_string()
+ ):
ui_auth_types.add(LoginType.SSO)
return ui_auth_types
@@ -424,6 +426,7 @@ class AuthHandler(BaseHandler):
request: SynapseRequest,
clientdict: Dict[str, Any],
description: str,
+ get_new_session_data: Optional[Callable[[], JsonDict]] = None,
) -> Tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
@@ -447,6 +450,13 @@ class AuthHandler(BaseHandler):
description: A human readable string to be displayed to the user that
describes the operation happening on their account.
+ get_new_session_data:
+ an optional callback which will be called when starting a new session.
+ it should return data to be stored as part of the session.
+
+ The keys of the returned data should be entries in
+ UIAuthSessionDataConstants.
+
Returns:
A tuple of (creds, params, session_id).
@@ -474,10 +484,15 @@ class AuthHandler(BaseHandler):
# If there's no session ID, create a new session.
if not sid:
+ new_session_data = get_new_session_data() if get_new_session_data else {}
+
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
+ for k, v in new_session_data.items():
+ await self.set_session_data(session.session_id, k, v)
+
else:
try:
session = await self.store.get_ui_auth_session(sid)
@@ -639,7 +654,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
- key: The key to store the data under
+ key: The key to store the data under. An entry from
+ UIAuthSessionDataConstants.
value: The data to store
"""
try:
@@ -655,7 +671,8 @@ class AuthHandler(BaseHandler):
Args:
session_id: The ID of this session as returned from check_auth
- key: The key to store the data under
+ key: The key the data was stored under. An entry from
+ UIAuthSessionDataConstants.
default: Value to return if the key has not been set
"""
try:
@@ -1329,12 +1346,12 @@ class AuthHandler(BaseHandler):
else:
return False
- async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+ async def start_sso_ui_auth(self, request: SynapseRequest, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
Args:
- redirect_url: The URL to redirect to the SSO provider.
+ request: The incoming HTTP request
session_id: The user interactive authentication session ID.
Returns:
@@ -1344,6 +1361,35 @@ class AuthHandler(BaseHandler):
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
+
+ user_id_to_verify = await self.get_session_data(
+ session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+ ) # type: str
+
+ idps = await self.hs.get_sso_handler().get_identity_providers_for_user(
+ user_id_to_verify
+ )
+
+ if not idps:
+ # we checked that the user had some remote identities before offering an SSO
+ # flow, so either it's been deleted or the client has requested SSO despite
+ # it not being offered.
+ raise SynapseError(400, "User has no SSO identities")
+
+ # for now, just pick one
+ idp_id, sso_auth_provider = next(iter(idps.items()))
+ if len(idps) > 0:
+ logger.warning(
+ "User %r has previously logged in with multiple SSO IdPs; arbitrarily "
+ "picking %r",
+ user_id_to_verify,
+ idp_id,
+ )
+
+ redirect_url = await sso_auth_provider.handle_redirect_request(
+ request, None, session_id
+ )
+
return self._sso_auth_confirm_template.render(
description=session.description, redirect_url=redirect_url,
)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 740df7e4a0..d096e0b091 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -167,6 +167,37 @@ class SsoHandler:
"""Get the configured identity providers"""
return self._identity_providers
+ async def get_identity_providers_for_user(
+ self, user_id: str
+ ) -> Mapping[str, SsoIdentityProvider]:
+ """Get the SsoIdentityProviders which a user has used
+
+ Given a user id, get the identity providers that that user has used to log in
+ with in the past (and thus could use to re-identify themselves for UI Auth).
+
+ Args:
+ user_id: MXID of user to look up
+
+ Raises:
+ a map of idp_id to SsoIdentityProvider
+ """
+ external_ids = await self._store.get_external_ids_by_user(user_id)
+
+ valid_idps = {}
+ for idp_id, _ in external_ids:
+ idp = self._identity_providers.get(idp_id)
+ if not idp:
+ logger.warning(
+ "User %r has an SSO mapping for IdP %r, but this is no longer "
+ "configured.",
+ user_id,
+ idp_id,
+ )
+ else:
+ valid_idps[idp_id] = idp
+
+ return valid_idps
+
def render_error(
self,
request: Request,
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
index 824f37f8f8..a68d5e790e 100644
--- a/synapse/handlers/ui_auth/__init__.py
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -20,3 +20,18 @@ TODO: move more stuff out of AuthHandler in here.
"""
from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
+
+
+class UIAuthSessionDataConstants:
+ """Constants for use with AuthHandler.set_session_data"""
+
+ # used during registration and password reset to store a hashed copy of the
+ # password, so that the client does not need to submit it each time.
+ PASSWORD_HASH = "password_hash"
+
+ # used during registration to store the mxid of the registered user
+ REGISTERED_USER_ID = "registered_user_id"
+
+ # used by validate_user_via_ui_auth to store the mxid of the user we are validating
+ # for.
+ REQUEST_USER_ID = "request_user_id"
|