From a8ac40445c98b9e1fc2538d7d4ec49c80b0298ac Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 13 Sep 2019 15:20:49 +0100 Subject: Record mappings from saml users in an external table We want to assign unique mxids to saml users based on an incrementing suffix. For that to work, we need to record the allocated mxid in a separate table. --- synapse/handlers/saml_handler.py | 103 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 95 insertions(+), 8 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a1ce6929cf..5fa8272dc9 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -21,6 +21,8 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler +from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) @@ -29,12 +31,26 @@ class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._sso_auth_handler = SSOAuthHandler(hs) + self._registration_handler = hs.get_registration_handler() + + self._clock = hs.get_clock() + self._datastore = hs.get_datastore() + self._hostname = hs.hostname + self._saml2_session_lifetime = hs.config.saml2_session_lifetime + self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute + self._grandfathered_mxid_source_attribute = ( + hs.config.saml2_grandfathered_mxid_source_attribute + ) + self._mxid_mapper = hs.config.saml2_mxid_mapper + + # identifier for the external_ids table + self._auth_provider_id = "saml" # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} - self._clock = hs.get_clock() - self._saml2_session_lifetime = hs.config.saml2_session_lifetime + # a lock on the mappings + self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) def handle_redirect_request(self, client_redirect_url): """Handle an incoming request to /login/sso/redirect @@ -60,7 +76,7 @@ class SamlHandler: # this shouldn't happen! raise Exception("prepare_for_authenticate didn't return a Location header") - def handle_saml_response(self, request): + async def handle_saml_response(self, request): """Handle an incoming request to /_matrix/saml2/authn_response Args: @@ -77,6 +93,10 @@ class SamlHandler: # the dict. self.expire_sessions() + user_id = await self._map_saml_response_to_user(resp_bytes) + self._sso_auth_handler.complete_sso_login(user_id, request, relay_state) + + async def _map_saml_response_to_user(self, resp_bytes): try: saml2_auth = self._saml_client.parse_authn_request_response( resp_bytes, @@ -91,18 +111,85 @@ class SamlHandler: logger.warning("SAML2 response was not signed") raise SynapseError(400, "SAML2 response was not signed") - if "uid" not in saml2_auth.ava: + try: + remote_user_id = saml2_auth.ava["uid"][0] + except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") raise SynapseError(400, "uid not in SAML2 response") + try: + mxid_source = saml2_auth.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) - username = saml2_auth.ava["uid"][0] displayName = saml2_auth.ava.get("displayName", [None])[0] - return self._sso_auth_handler.on_successful_auth( - username, request, relay_state, user_display_name=displayName - ) + with (await self._mapping_lock.queue(self._auth_provider_id)): + # first of all, check if we already have a mapping for this user + logger.info( + "Looking for existing mapping for user %s:%s", + self._auth_provider_id, + remote_user_id, + ) + registered_user_id = await self._datastore.get_user_by_external_id( + self._auth_provider_id, remote_user_id + ) + if registered_user_id is not None: + logger.info("Found existing mapping %s", registered_user_id) + return registered_user_id + + # backwards-compatibility hack: see if there is an existing user with a + # suitable mapping from the uid + if ( + self._grandfathered_mxid_source_attribute + and self._grandfathered_mxid_source_attribute in saml2_auth.ava + ): + attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] + user_id = UserID( + map_username_to_mxid_localpart(attrval), self._hostname + ).to_string() + logger.info( + "Looking for existing account based on mapped %s %s", + self._grandfathered_mxid_source_attribute, + user_id, + ) + + users = await self._datastore.get_users_by_id_case_insensitive(user_id) + if users: + registered_user_id = list(users.keys())[0] + logger.info("Grandfathering mapping to %s", registered_user_id) + await self._datastore.record_user_external_id( + self._auth_provider_id, remote_user_id, registered_user_id + ) + return registered_user_id + + # figure out a new mxid for this user + base_mxid_localpart = self._mxid_mapper(mxid_source) + + suffix = 0 + while True: + localpart = base_mxid_localpart + (str(suffix) if suffix else "") + if not await self._datastore.get_users_by_id_case_insensitive( + UserID(localpart, self._hostname).to_string() + ): + break + suffix += 1 + logger.info("Allocating mxid for new user with localpart %s", localpart) + + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=displayName + ) + await self._datastore.record_user_external_id( + self._auth_provider_id, remote_user_id, registered_user_id + ) + return registered_user_id def expire_sessions(self): expire_before = self._clock.time_msec() - self._saml2_session_lifetime -- cgit 1.4.1