diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/handlers/saml_handler.py | 161 |
1 files changed, 51 insertions, 110 deletions
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 285c481a96..76d4169fe2 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -24,7 +24,8 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.saml2_config import SamlAttributeRequirement -from synapse.http.server import respond_with_html +from synapse.handlers._base import BaseHandler +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: - import synapse.server + from synapse.server import HomeServer logger = logging.getLogger(__name__) -class MappingException(Exception): - """Used to catch errors when mapping the SAML2 response to a user.""" - - @attr.s(slots=True) class Saml2SessionData: """Data we track about SAML2 sessions""" @@ -57,17 +54,14 @@ class Saml2SessionData: ui_auth_session_id = attr.ib(type=Optional[str], default=None) -class SamlHandler: - def __init__(self, hs: "synapse.server.HomeServer"): - self.hs = hs +class SamlHandler(BaseHandler): + def __init__(self, hs: "HomeServer"): + super().__init__(hs) self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._auth = hs.get_auth() + self._saml_idp_entityid = hs.config.saml2_idp_entityid self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() - self._clock = hs.get_clock() - self._datastore = hs.get_datastore() - self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute @@ -88,26 +82,9 @@ class SamlHandler: self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] # a lock on the mappings - self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) - - def _render_error( - self, request, error: str, error_description: Optional[str] = None - ) -> None: - """Render the error template and respond to the request with it. - - This is used to show errors to the user. The template of this page can - be found under `synapse/res/templates/sso_error.html`. + self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock) - Args: - request: The incoming request from the browser. - We'll respond with an HTML page describing the error. - error: A technical identifier for this error. - error_description: A human-readable description of the error. - """ - html = self._error_template.render( - error=error, error_description=error_description - ) - respond_with_html(request, 400, html) + self._sso_handler = hs.get_sso_handler() def handle_redirect_request( self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None @@ -124,13 +101,13 @@ class SamlHandler: URL to redirect to """ reqid, info = self._saml_client.prepare_for_authenticate( - relay_state=client_redirect_url + entityid=self._saml_idp_entityid, relay_state=client_redirect_url ) # Since SAML sessions timeout it is useful to log when they were created. logger.info("Initiating a new SAML session: %s" % (reqid,)) - now = self._clock.time_msec() + now = self.clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( creation_time=now, ui_auth_session_id=ui_auth_session_id, ) @@ -171,12 +148,12 @@ class SamlHandler: # in the (user-visible) exception message, so let's log the exception here # so we can track down the session IDs later. logger.warning(str(e)) - self._render_error( + self._sso_handler.render_error( request, "unsolicited_response", "Unexpected SAML2 login." ) return except Exception as e: - self._render_error( + self._sso_handler.render_error( request, "invalid_response", "Unable to parse SAML2 response: %s." % (e,), @@ -184,7 +161,7 @@ class SamlHandler: return if saml2_auth.not_signed: - self._render_error( + self._sso_handler.render_error( request, "unsigned_respond", "SAML2 response was not signed." ) return @@ -210,15 +187,13 @@ class SamlHandler: # attributes. for requirement in self._saml2_attribute_requirements: if not _check_attribute_requirement(saml2_auth.ava, requirement): - self._render_error( + self._sso_handler.render_error( request, "unauthorised", "You are not authorised to log in here." ) return # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ - 0 - ].decode("ascii", "surrogateescape") + user_agent = request.get_user_agent("") ip_address = self.hs.get_ip_from_request(request) # Call the mapper to register/login the user @@ -228,7 +203,7 @@ class SamlHandler: ) except MappingException as e: logger.exception("Could not map user") - self._render_error(request, "mapping_error", str(e)) + self._sso_handler.render_error(request, "mapping_error", str(e)) return # Complete the interactive auth session or the login. @@ -274,20 +249,26 @@ class SamlHandler: "Failed to extract remote user id from SAML response" ) - with (await self._mapping_lock.queue(self._auth_provider_id)): - # first of all, check if we already have a mapping for this user - logger.info( - "Looking for existing mapping for user %s:%s", - self._auth_provider_id, - remote_user_id, + async def saml_response_to_remapped_user_attributes( + failures: int, + ) -> UserAttributes: + """ + Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form. + + This is backwards compatibility for abstraction for the SSO handler. + """ + # Call the mapping provider. + result = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, failures, client_redirect_url ) - registered_user_id = await self._datastore.get_user_by_external_id( - self._auth_provider_id, remote_user_id + # Remap some of the results. + return UserAttributes( + localpart=result.get("mxid_localpart"), + display_name=result.get("displayname"), + emails=result.get("emails", []), ) - if registered_user_id is not None: - logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id + async def grandfather_existing_users() -> Optional[str]: # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid if ( @@ -296,75 +277,35 @@ class SamlHandler: ): attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] user_id = UserID( - map_username_to_mxid_localpart(attrval), self._hostname + map_username_to_mxid_localpart(attrval), self.server_name ).to_string() - logger.info( + + logger.debug( "Looking for existing account based on mapped %s %s", self._grandfathered_mxid_source_attribute, user_id, ) - users = await self._datastore.get_users_by_id_case_insensitive(user_id) + users = await self.store.get_users_by_id_case_insensitive(user_id) if users: registered_user_id = list(users.keys())[0] logger.info("Grandfathering mapping to %s", registered_user_id) - await self._datastore.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id - ) return registered_user_id - # Map saml response to user attributes using the configured mapping provider - for i in range(1000): - attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( - saml2_auth, i, client_redirect_url=client_redirect_url, - ) - - logger.debug( - "Retrieved SAML attributes from user mapping provider: %s " - "(attempt %d)", - attribute_dict, - i, - ) - - localpart = attribute_dict.get("mxid_localpart") - if not localpart: - raise MappingException( - "Error parsing SAML2 response: SAML mapping provider plugin " - "did not return a mxid_localpart value" - ) - - displayname = attribute_dict.get("displayname") - emails = attribute_dict.get("emails", []) - - # Check if this mxid already exists - if not await self._datastore.get_users_by_id_case_insensitive( - UserID(localpart, self._hostname).to_string() - ): - # This mxid is free - break - else: - # Unable to generate a username in 1000 iterations - # Break and return error to the user - raise MappingException( - "Unable to generate a Matrix ID from the SAML response" - ) + return None - logger.info("Mapped SAML user to local part %s", localpart) - - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=displayname, - bind_emails=emails, - user_agent_ips=(user_agent, ip_address), - ) - - await self._datastore.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id + with (await self._mapping_lock.queue(self._auth_provider_id)): + return await self._sso_handler.get_mxid_from_sso( + self._auth_provider_id, + remote_user_id, + user_agent, + ip_address, + saml_response_to_remapped_user_attributes, + grandfather_existing_users, ) - return registered_user_id def expire_sessions(self): - expire_before = self._clock.time_msec() - self._saml2_session_lifetime + expire_before = self.clock.time_msec() - self._saml2_session_lifetime to_expire = set() for reqid, data in self._outstanding_requests_dict.items(): if data.creation_time < expire_before: @@ -476,11 +417,11 @@ class DefaultSamlMappingProvider: ) # Use the configured mapper for this mxid_source - base_mxid_localpart = self._mxid_mapper(mxid_source) + localpart = self._mxid_mapper(mxid_source) # Append suffix integer if last call to this function failed to produce - # a usable mxid - localpart = base_mxid_localpart + (str(failures) if failures else "") + # a usable mxid. + localpart += str(failures) if failures else "" # Retrieve the display name from the saml response # If displayname is None, the mxid_localpart will be used instead |