diff options
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r-- | synapse/handlers/sso.py | 90 |
1 files changed, 87 insertions, 3 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index b450668f1c..a63fd52485 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -16,10 +16,12 @@ import abc import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Dict, Iterable, + List, Mapping, Optional, Set, @@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError +from synapse.config.sso import SsoAttributeRequirement from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect @@ -742,7 +745,11 @@ class SsoHandler: use_display_name: whether the user wants to use the suggested display name emails_to_use: emails that the user would like to use """ - session = self.get_mapping_session(session_id) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return # update the session with the user's choices session.chosen_localpart = localpart @@ -793,7 +800,12 @@ class SsoHandler: session_id, terms_version, ) - session = self.get_mapping_session(session_id) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return + session.terms_accepted_version = terms_version # we're done; now we can register the user @@ -808,7 +820,11 @@ class SsoHandler: request: HTTP request session_id: ID of the username mapping session, extracted from a cookie """ - session = self.get_mapping_session(session_id) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return logger.info( "[session %s] Registering localpart %s", @@ -880,6 +896,41 @@ class SsoHandler: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + def check_required_attributes( + self, + request: SynapseRequest, + attributes: Mapping[str, List[Any]], + attribute_requirements: Iterable[SsoAttributeRequirement], + ) -> bool: + """ + Confirm that the required attributes were present in the SSO response. + + If all requirements are met, this will return True. + + If any requirement is not met, then the request will be finalized by + showing an error page to the user and False will be returned. + + Args: + request: The request to (potentially) respond to. + attributes: The attributes from the SSO IdP. + attribute_requirements: The requirements that attributes must meet. + + Returns: + True if all requirements are met, False if any attribute fails to + meet the requirement. + + """ + # Ensure that the attributes of the logged in user meet the required + # attributes. + for requirement in attribute_requirements: + if not _check_attribute_requirement(attributes, requirement): + self.render_error( + request, "unauthorised", "You are not authorised to log in here." + ) + return False + + return True + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie @@ -890,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: if not session_id: raise SynapseError(code=400, msg="missing session_id") return session_id.decode("ascii", errors="replace") + + +def _check_attribute_requirement( + attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement +) -> bool: + """Check if SSO attributes meet the proper requirements. + + Args: + attributes: A mapping of attributes to an iterable of one or more values. + requirement: The configured requirement to check. + + Returns: + True if the required attribute was found and had a proper value. + """ + if req.attribute not in attributes: + logger.info("SSO attribute missing: %s", req.attribute) + return False + + # If the requirement is None, the attribute existing is enough. + if req.value is None: + return True + + values = attributes[req.attribute] + if req.value in values: + return True + + logger.info( + "SSO attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + return False |