summary refs log tree commit diff
path: root/synapse/handlers/sso.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sso.py')
-rw-r--r--synapse/handlers/sso.py93
1 files changed, 86 insertions, 7 deletions
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py

index 96ccd991ed..514b1f69d8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
@@ -16,10 +16,12 @@ import abc import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Dict, Iterable, + List, Mapping, Optional, Set, @@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError +from synapse.config.sso import SsoAttributeRequirement from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect @@ -324,7 +327,8 @@ class SsoHandler: # Check if we already have a mapping for this user. previously_registered_user_id = await self._store.get_user_by_external_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # A match was found, return the user ID. @@ -413,7 +417,8 @@ class SsoHandler: with await self._mapping_lock.queue(auth_provider_id): # first of all, check if we already have a mapping for this user user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # Check for grandfathering of users. @@ -458,7 +463,8 @@ class SsoHandler: ) async def _call_attribute_mapper( - self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + self, + sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], ) -> UserAttributes: """Call the attribute mapper function in a loop, until we get a unique userid""" for i in range(self._MAP_USERNAME_RETRIES): @@ -629,7 +635,8 @@ class SsoHandler: """ user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) user_id_to_verify = await self._auth_handler.get_session_data( @@ -668,7 +675,8 @@ class SsoHandler: # render an error page. html = self._bad_user_template.render( - server_name=self._server_name, user_id_to_verify=user_id_to_verify, + server_name=self._server_name, + user_id_to_verify=user_id_to_verify, ) respond_with_html(request, 200, html) @@ -692,7 +700,9 @@ class SsoHandler: raise SynapseError(400, "unknown session") async def check_username_availability( - self, localpart: str, session_id: str, + self, + localpart: str, + session_id: str, ) -> bool: """Handle an "is username available" callback check @@ -830,7 +840,8 @@ class SsoHandler: ) attributes = UserAttributes( - localpart=session.chosen_localpart, emails=session.emails_to_use, + localpart=session.chosen_localpart, + emails=session.emails_to_use, ) if session.use_display_name: @@ -893,6 +904,41 @@ class SsoHandler: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + def check_required_attributes( + self, + request: SynapseRequest, + attributes: Mapping[str, List[Any]], + attribute_requirements: Iterable[SsoAttributeRequirement], + ) -> bool: + """ + Confirm that the required attributes were present in the SSO response. + + If all requirements are met, this will return True. + + If any requirement is not met, then the request will be finalized by + showing an error page to the user and False will be returned. + + Args: + request: The request to (potentially) respond to. + attributes: The attributes from the SSO IdP. + attribute_requirements: The requirements that attributes must meet. + + Returns: + True if all requirements are met, False if any attribute fails to + meet the requirement. + + """ + # Ensure that the attributes of the logged in user meet the required + # attributes. + for requirement in attribute_requirements: + if not _check_attribute_requirement(attributes, requirement): + self.render_error( + request, "unauthorised", "You are not authorised to log in here." + ) + return False + + return True + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie @@ -903,3 +949,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: if not session_id: raise SynapseError(code=400, msg="missing session_id") return session_id.decode("ascii", errors="replace") + + +def _check_attribute_requirement( + attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement +) -> bool: + """Check if SSO attributes meet the proper requirements. + + Args: + attributes: A mapping of attributes to an iterable of one or more values. + requirement: The configured requirement to check. + + Returns: + True if the required attribute was found and had a proper value. + """ + if req.attribute not in attributes: + logger.info("SSO attribute missing: %s", req.attribute) + return False + + # If the requirement is None, the attribute existing is enough. + if req.value is None: + return True + + values = attributes[req.attribute] + if req.value in values: + return True + + logger.info( + "SSO attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + return False