From ee382025b0c264701fc320133912e9fece40b021 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 17 Nov 2020 09:46:23 -0500 Subject: Abstract shared SSO code. (#8765) De-duplicates code between the SAML and OIDC implementations. --- synapse/handlers/saml_handler.py | 77 ++++++++++++---------------------------- 1 file changed, 23 insertions(+), 54 deletions(-) (limited to 'synapse/handlers/saml_handler.py') diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index fd6c5e9ea8..aee772239a 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -24,7 +24,8 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.saml2_config import SamlAttributeRequirement -from synapse.http.server import respond_with_html +from synapse.handlers._base import BaseHandler +from synapse.handlers.sso import MappingException from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -42,10 +43,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class MappingException(Exception): - """Used to catch errors when mapping the SAML2 response to a user.""" - - @attr.s(slots=True) class Saml2SessionData: """Data we track about SAML2 sessions""" @@ -57,17 +54,13 @@ class Saml2SessionData: ui_auth_session_id = attr.ib(type=Optional[str], default=None) -class SamlHandler: +class SamlHandler(BaseHandler): def __init__(self, hs: "synapse.server.HomeServer"): - self.hs = hs + super().__init__(hs) self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() - self._clock = hs.get_clock() - self._datastore = hs.get_datastore() - self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute @@ -88,26 +81,9 @@ class SamlHandler: self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] # a lock on the mappings - self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) - - def _render_error( - self, request, error: str, error_description: Optional[str] = None - ) -> None: - """Render the error template and respond to the request with it. + self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock) - This is used to show errors to the user. The template of this page can - be found under `synapse/res/templates/sso_error.html`. - - Args: - request: The incoming request from the browser. - We'll respond with an HTML page describing the error. - error: A technical identifier for this error. - error_description: A human-readable description of the error. - """ - html = self._error_template.render( - error=error, error_description=error_description - ) - respond_with_html(request, 400, html) + self._sso_handler = hs.get_sso_handler() def handle_redirect_request( self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None @@ -130,7 +106,7 @@ class SamlHandler: # Since SAML sessions timeout it is useful to log when they were created. logger.info("Initiating a new SAML session: %s" % (reqid,)) - now = self._clock.time_msec() + now = self.clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( creation_time=now, ui_auth_session_id=ui_auth_session_id, ) @@ -171,12 +147,12 @@ class SamlHandler: # in the (user-visible) exception message, so let's log the exception here # so we can track down the session IDs later. logger.warning(str(e)) - self._render_error( + self._sso_handler.render_error( request, "unsolicited_response", "Unexpected SAML2 login." ) return except Exception as e: - self._render_error( + self._sso_handler.render_error( request, "invalid_response", "Unable to parse SAML2 response: %s." % (e,), @@ -184,7 +160,7 @@ class SamlHandler: return if saml2_auth.not_signed: - self._render_error( + self._sso_handler.render_error( request, "unsigned_respond", "SAML2 response was not signed." ) return @@ -210,7 +186,7 @@ class SamlHandler: # attributes. for requirement in self._saml2_attribute_requirements: if not _check_attribute_requirement(saml2_auth.ava, requirement): - self._render_error( + self._sso_handler.render_error( request, "unauthorised", "You are not authorised to log in here." ) return @@ -226,7 +202,7 @@ class SamlHandler: ) except MappingException as e: logger.exception("Could not map user") - self._render_error(request, "mapping_error", str(e)) + self._sso_handler.render_error(request, "mapping_error", str(e)) return # Complete the interactive auth session or the login. @@ -274,17 +250,11 @@ class SamlHandler: with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user - logger.info( - "Looking for existing mapping for user %s:%s", - self._auth_provider_id, - remote_user_id, + previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id( + self._auth_provider_id, remote_user_id, ) - registered_user_id = await self._datastore.get_user_by_external_id( - self._auth_provider_id, remote_user_id - ) - if registered_user_id is not None: - logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id + if previously_registered_user_id: + return previously_registered_user_id # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid @@ -294,7 +264,7 @@ class SamlHandler: ): attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] user_id = UserID( - map_username_to_mxid_localpart(attrval), self._hostname + map_username_to_mxid_localpart(attrval), self.server_name ).to_string() logger.info( "Looking for existing account based on mapped %s %s", @@ -302,11 +272,11 @@ class SamlHandler: user_id, ) - users = await self._datastore.get_users_by_id_case_insensitive(user_id) + users = await self.store.get_users_by_id_case_insensitive(user_id) if users: registered_user_id = list(users.keys())[0] logger.info("Grandfathering mapping to %s", registered_user_id) - await self._datastore.record_user_external_id( + await self.store.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) return registered_user_id @@ -335,8 +305,8 @@ class SamlHandler: emails = attribute_dict.get("emails", []) # Check if this mxid already exists - if not await self._datastore.get_users_by_id_case_insensitive( - UserID(localpart, self._hostname).to_string() + if not await self.store.get_users_by_id_case_insensitive( + UserID(localpart, self.server_name).to_string() ): # This mxid is free break @@ -348,7 +318,6 @@ class SamlHandler: ) logger.info("Mapped SAML user to local part %s", localpart) - registered_user_id = await self._registration_handler.register_user( localpart=localpart, default_display_name=displayname, @@ -356,13 +325,13 @@ class SamlHandler: user_agent_ips=(user_agent, ip_address), ) - await self._datastore.record_user_external_id( + await self.store.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) return registered_user_id def expire_sessions(self): - expire_before = self._clock.time_msec() - self._saml2_session_lifetime + expire_before = self.clock.time_msec() - self._saml2_session_lifetime to_expire = set() for reqid, data in self._outstanding_requests_dict.items(): if data.creation_time < expire_before: -- cgit 1.5.1 From 53a6f5ddf0c6bf2a8c8c3b757fb54a0c7755daf7 Mon Sep 17 00:00:00 2001 From: Ben Banfield-Zanin Date: Thu, 19 Nov 2020 14:57:13 +0000 Subject: SAML: Allow specifying the IdP entityid to use. (#8630) If the SAML metadata includes multiple IdPs it is necessary to specify which IdP to redirect users to for authentication. --- changelog.d/8630.feature | 1 + docs/sample_config.yaml | 8 ++++++++ synapse/config/saml2_config.py | 10 ++++++++++ synapse/handlers/saml_handler.py | 3 ++- 4 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8630.feature (limited to 'synapse/handlers/saml_handler.py') diff --git a/changelog.d/8630.feature b/changelog.d/8630.feature new file mode 100644 index 0000000000..706051f131 --- /dev/null +++ b/changelog.d/8630.feature @@ -0,0 +1 @@ +Allow specification of the SAML IdP if the metadata returns multiple IdPs. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index bedc147770..52a1d8b853 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1674,6 +1674,14 @@ saml2_config: # - attribute: department # value: "sales" + # If the metadata XML contains multiple IdP entities then the `idp_entityid` + # option must be set to the entity to redirect users to. + # + # Most deployments only have a single IdP entity and so should omit this + # option. + # + #idp_entityid: 'https://our_idp/entityid' + # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. # diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index f233854941..c1b8e98ae0 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -90,6 +90,8 @@ class SAML2Config(Config): "grandfathered_mxid_source_attribute", "uid" ) + self.saml2_idp_entityid = saml2_config.get("idp_entityid", None) + # user_mapping_provider may be None if the key is present but has no value ump_dict = saml2_config.get("user_mapping_provider") or {} @@ -383,6 +385,14 @@ class SAML2Config(Config): # value: "staff" # - attribute: department # value: "sales" + + # If the metadata XML contains multiple IdP entities then the `idp_entityid` + # option must be set to the entity to redirect users to. + # + # Most deployments only have a single IdP entity and so should omit this + # option. + # + #idp_entityid: 'https://our_idp/entityid' """ % { "config_dir_path": config_dir_path } diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index aee772239a..9bf430b656 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -58,6 +58,7 @@ class SamlHandler(BaseHandler): def __init__(self, hs: "synapse.server.HomeServer"): super().__init__(hs) self._saml_client = Saml2Client(hs.config.saml2_sp_config) + self._saml_idp_entityid = hs.config.saml2_idp_entityid self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -100,7 +101,7 @@ class SamlHandler(BaseHandler): URL to redirect to """ reqid, info = self._saml_client.prepare_for_authenticate( - relay_state=client_redirect_url + entityid=self._saml_idp_entityid, relay_state=client_redirect_url ) # Since SAML sessions timeout it is useful to log when they were created. -- cgit 1.5.1 From 79bfe966e08a2212cc2fae2b00f5efb2c2185543 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 19 Nov 2020 14:25:17 -0500 Subject: Improve error checking for OIDC/SAML mapping providers (#8774) Checks that the localpart returned by mapping providers for SAML and OIDC are valid before registering new users. Extends the OIDC tests for existing users and invalid data. --- UPGRADE.rst | 30 ++++++++++++++ changelog.d/8774.misc | 1 + docs/sso_mapping_providers.md | 9 +++- synapse/handlers/oidc_handler.py | 25 ++++++++--- synapse/handlers/saml_handler.py | 6 +++ synapse/types.py | 6 +-- tests/handlers/test_oidc.py | 89 +++++++++++++++++++++++++++++++--------- 7 files changed, 137 insertions(+), 29 deletions(-) create mode 100644 changelog.d/8774.misc (limited to 'synapse/handlers/saml_handler.py') diff --git a/UPGRADE.rst b/UPGRADE.rst index 7c19cf2a70..4de1bb5841 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,36 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.24.0 +==================== + +Custom OpenID Connect mapping provider breaking change +------------------------------------------------------ + +This release allows the OpenID Connect mapping provider to perform normalisation +of the localpart of the Matrix ID. This allows for the mapping provider to +specify different algorithms, instead of the [default way](https://matrix.org/docs/spec/appendices#mapping-from-other-character-sets). + +If your Synapse configuration uses a custom mapping provider +(`oidc_config.user_mapping_provider.module` is specified and not equal to +`synapse.handlers.oidc_handler.JinjaOidcMappingProvider`) then you *must* ensure +that `map_user_attributes` of the mapping provider performs some normalisation +of the `localpart` returned. To match previous behaviour you can use the +`map_username_to_mxid_localpart` function provided by Synapse. An example is +shown below: + +.. code-block:: python + + from synapse.types import map_username_to_mxid_localpart + + class MyMappingProvider: + def map_user_attributes(self, userinfo, token): + # ... your custom logic ... + sso_user_id = ... + localpart = map_username_to_mxid_localpart(sso_user_id) + + return {"localpart": localpart} + Upgrading to v1.23.0 ==================== diff --git a/changelog.d/8774.misc b/changelog.d/8774.misc new file mode 100644 index 0000000000..57cca8fee5 --- /dev/null +++ b/changelog.d/8774.misc @@ -0,0 +1 @@ +Add additional error checking for OpenID Connect and SAML mapping providers. diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index 32b06aa2c5..707dd73978 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -15,8 +15,15 @@ where SAML mapping providers come into play. SSO mapping providers are currently supported for OpenID and SAML SSO configurations. Please see the details below for how to implement your own. +It is the responsibility of the mapping provider to normalise the SSO attributes +and map them to a valid Matrix ID. The +[specification for Matrix IDs](https://matrix.org/docs/spec/appendices#user-identifiers) +has some information about what is considered valid. Alternately an easy way to +ensure it is valid is to use a Synapse utility function: +`synapse.types.map_username_to_mxid_localpart`. + External mapping providers are provided to Synapse in the form of an external -Python module. You can retrieve this module from [PyPi](https://pypi.org) or elsewhere, +Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere, but it must be importable via Synapse (e.g. it must be in the same virtualenv as Synapse). The Synapse config is then modified to point to the mapping provider (and optionally provide additional configuration for it). diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index be8562d47b..4bfd8d5617 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -38,7 +38,12 @@ from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart +from synapse.types import ( + JsonDict, + UserID, + contains_invalid_mxid_characters, + map_username_to_mxid_localpart, +) from synapse.util import json_decoder if TYPE_CHECKING: @@ -885,10 +890,12 @@ class OidcHandler(BaseHandler): "Retrieved user attributes from user mapping provider: %r", attributes ) - if not attributes["localpart"]: - raise MappingException("localpart is empty") - - localpart = map_username_to_mxid_localpart(attributes["localpart"]) + localpart = attributes["localpart"] + if not localpart: + raise MappingException( + "Error parsing OIDC response: OIDC mapping provider plugin " + "did not return a localpart value" + ) user_id = UserID(localpart, self.server_name).to_string() users = await self.store.get_users_by_id_case_insensitive(user_id) @@ -908,6 +915,11 @@ class OidcHandler(BaseHandler): # This mxid is taken raise MappingException("mxid '{}' is already taken".format(user_id)) else: + # Since the localpart is provided via a potentially untrusted module, + # ensure the MXID is valid before registering. + if contains_invalid_mxid_characters(localpart): + raise MappingException("localpart is invalid: %s" % (localpart,)) + # It's the first time this user is logging in and the mapped mxid was # not taken, register the user registered_user_id = await self._registration_handler.register_user( @@ -1076,6 +1088,9 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): ) -> UserAttribute: localpart = self._config.localpart_template.render(user=userinfo).strip() + # Ensure only valid characters are included in the MXID. + localpart = map_username_to_mxid_localpart(localpart) + display_name = None # type: Optional[str] if self._config.display_name_template is not None: display_name = self._config.display_name_template.render( diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 9bf430b656..5d9b555b13 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -31,6 +31,7 @@ from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi from synapse.types import ( UserID, + contains_invalid_mxid_characters, map_username_to_mxid_localpart, mxid_localpart_allowed_characters, ) @@ -318,6 +319,11 @@ class SamlHandler(BaseHandler): "Unable to generate a Matrix ID from the SAML response" ) + # Since the localpart is provided via a potentially untrusted module, + # ensure the MXID is valid before registering. + if contains_invalid_mxid_characters(localpart): + raise MappingException("localpart is invalid: %s" % (localpart,)) + logger.info("Mapped SAML user to local part %s", localpart) registered_user_id = await self._registration_handler.register_user( localpart=localpart, diff --git a/synapse/types.py b/synapse/types.py index 66bb5bac8d..3ab6bdbe06 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -317,14 +317,14 @@ mxid_localpart_allowed_characters = set( ) -def contains_invalid_mxid_characters(localpart): +def contains_invalid_mxid_characters(localpart: str) -> bool: """Check for characters not allowed in an mxid or groupid localpart Args: - localpart (basestring): the localpart to be checked + localpart: the localpart to be checked Returns: - bool: True if there are any naughty characters + True if there are any naughty characters """ return any(c not in mxid_localpart_allowed_characters for c in localpart) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 630e6da808..b4fa02acc4 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json from urllib.parse import parse_qs, urlparse @@ -24,12 +23,8 @@ import pymacaroons from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone -from synapse.handlers.oidc_handler import ( - MappingException, - OidcError, - OidcHandler, - OidcMappingProvider, -) +from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider +from synapse.handlers.sso import MappingException from synapse.types import UserID from tests.unittest import HomeserverTestCase, override_config @@ -132,14 +127,13 @@ class OidcHandlerTestCase(HomeserverTestCase): config = self.default_config() config["public_baseurl"] = BASE_URL - oidc_config = {} - oidc_config["enabled"] = True - oidc_config["client_id"] = CLIENT_ID - oidc_config["client_secret"] = CLIENT_SECRET - oidc_config["issuer"] = ISSUER - oidc_config["scopes"] = SCOPES - oidc_config["user_mapping_provider"] = { - "module": __name__ + ".TestMappingProvider", + oidc_config = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, } # Update this config with what's in the default config so that @@ -705,13 +699,13 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" store = self.hs.get_datastore() - user4 = UserID.from_string("@test_user_4:test") + user = UserID.from_string("@test_user:test") self.get_success( - store.register_user(user_id=user4.to_string(), password_hash=None) + store.register_user(user_id=user.to_string(), password_hash=None) ) userinfo = { - "sub": "test4", - "username": "test_user_4", + "sub": "test", + "username": "test_user", } token = {} mxid = self.get_success( @@ -719,4 +713,59 @@ class OidcHandlerTestCase(HomeserverTestCase): userinfo, token, "user-agent", "10.10.10.10" ) ) - self.assertEqual(mxid, "@test_user_4:test") + self.assertEqual(mxid, "@test_user:test") + + # Register some non-exact matching cases. + user2 = UserID.from_string("@TEST_user_2:test") + self.get_success( + store.register_user(user_id=user2.to_string(), password_hash=None) + ) + user2_caps = UserID.from_string("@test_USER_2:test") + self.get_success( + store.register_user(user_id=user2_caps.to_string(), password_hash=None) + ) + + # Attempting to login without matching a name exactly is an error. + userinfo = { + "sub": "test2", + "username": "TEST_USER_2", + } + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertTrue( + str(e.value).startswith( + "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:" + ) + ) + + # Logging in when matching a name exactly should work. + user2 = UserID.from_string("@TEST_USER_2:test") + self.get_success( + store.register_user(user_id=user2.to_string(), password_hash=None) + ) + + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@TEST_USER_2:test") + + def test_map_userinfo_to_invalid_localpart(self): + """If the mapping provider generates an invalid localpart it should be rejected.""" + userinfo = { + "sub": "test2", + "username": "föö", + } + token = {} + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "localpart is invalid: föö") -- cgit 1.5.1 From 59a995f38dab8d83aee09ed7eb3858390ed53a65 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 23 Nov 2020 13:45:23 +0000 Subject: Improve logging of the mapping from SSO IDs to Matrix IDs. (#8773) --- changelog.d/8773.misc | 1 + synapse/handlers/saml_handler.py | 5 +++-- synapse/handlers/sso.py | 12 +++++++++--- 3 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8773.misc (limited to 'synapse/handlers/saml_handler.py') diff --git a/changelog.d/8773.misc b/changelog.d/8773.misc new file mode 100644 index 0000000000..62778ba410 --- /dev/null +++ b/changelog.d/8773.misc @@ -0,0 +1 @@ +Minor log line improvements for the SSO mapping code used to generate Matrix IDs from SSO IDs. diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 5d9b555b13..f4e8cbeac8 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -268,7 +268,8 @@ class SamlHandler(BaseHandler): user_id = UserID( map_username_to_mxid_localpart(attrval), self.server_name ).to_string() - logger.info( + + logger.debug( "Looking for existing account based on mapped %s %s", self._grandfathered_mxid_source_attribute, user_id, @@ -324,7 +325,7 @@ class SamlHandler(BaseHandler): if contains_invalid_mxid_characters(localpart): raise MappingException("localpart is invalid: %s" % (localpart,)) - logger.info("Mapped SAML user to local part %s", localpart) + logger.debug("Mapped SAML user to local part %s", localpart) registered_user_id = await self._registration_handler.register_user( localpart=localpart, default_display_name=displayname, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 9cb1866a71..cf7cb7754a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -71,19 +71,25 @@ class SsoHandler(BaseHandler): Returns: The mxid of a previously seen user. """ - # Check if we already have a mapping for this user. - logger.info( + logger.debug( "Looking for existing mapping for user %s:%s", auth_provider_id, remote_user_id, ) + + # Check if we already have a mapping for this user. previously_registered_user_id = await self.store.get_user_by_external_id( auth_provider_id, remote_user_id, ) # A match was found, return the user ID. if previously_registered_user_id is not None: - logger.info("Found existing mapping %s", previously_registered_user_id) + logger.info( + "Found existing mapping for IdP '%s' and remote_user_id '%s': %s", + auth_provider_id, + remote_user_id, + previously_registered_user_id, + ) return previously_registered_user_id # No match. -- cgit 1.5.1 From 6fde6aa9c02d35e0a908437ea49b275df9b58427 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Nov 2020 13:28:03 -0500 Subject: Properly report user-agent/IP during registration of SSO users. (#8784) This also expands type-hints to the SSO and registration code. Refactors the CAS code to more closely match OIDC/SAML. --- changelog.d/8784.misc | 1 + mypy.ini | 1 + synapse/handlers/cas_handler.py | 71 +++++++++---- synapse/handlers/oidc_handler.py | 2 +- synapse/handlers/register.py | 214 +++++++++++++++++++++------------------ synapse/handlers/saml_handler.py | 6 +- 6 files changed, 173 insertions(+), 122 deletions(-) create mode 100644 changelog.d/8784.misc (limited to 'synapse/handlers/saml_handler.py') diff --git a/changelog.d/8784.misc b/changelog.d/8784.misc new file mode 100644 index 0000000000..18a4263398 --- /dev/null +++ b/changelog.d/8784.misc @@ -0,0 +1 @@ +Fix a bug introduced in v1.20.0 where the user-agent and IP address reported during user registration for CAS, OpenID Connect, and SAML were of the wrong form. diff --git a/mypy.ini b/mypy.ini index fc9f8d8050..0cf7c93f45 100644 --- a/mypy.ini +++ b/mypy.ini @@ -37,6 +37,7 @@ files = synapse/handlers/presence.py, synapse/handlers/profile.py, synapse/handlers/read_marker.py, + synapse/handlers/register.py, synapse/handlers/room.py, synapse/handlers/room_member.py, synapse/handlers/room_member_worker.py, diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 048a3b3c0b..f4ea0a9767 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib -from typing import Dict, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple from xml.etree import ElementTree as ET from twisted.web.client import PartialDownloadError @@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -31,10 +34,10 @@ class CasHandler: Utility class for to handle the response from a CAS SSO service. Args: - hs (synapse.server.HomeServer) + hs """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -200,27 +203,57 @@ class CasHandler: args["session"] = session username, user_display_name = await self._validate_ticket(ticket, args) - localpart = map_username_to_mxid_localpart(username) - user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = await self._auth_handler.check_user_exists(user_id) + # Pull out the user-agent and IP from the request. + user_agent = request.get_user_agent("") + ip_address = self.hs.get_ip_from_request(request) + + # Get the matrix ID from the CAS username. + user_id = await self._map_cas_user_to_matrix_user( + username, user_display_name, user_agent, ip_address + ) if session: await self._auth_handler.complete_sso_ui_auth( - registered_user_id, session, request, + user_id, session, request, ) - else: - if not registered_user_id: - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) - - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=user_display_name, - user_agent_ips=(user_agent, ip_address), - ) + # If this not a UI auth request than there must be a redirect URL. + assert client_redirect_url await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url + user_id, request, client_redirect_url ) + + async def _map_cas_user_to_matrix_user( + self, + remote_user_id: str, + display_name: Optional[str], + user_agent: str, + ip_address: str, + ) -> str: + """ + Given a CAS username, retrieve the user ID for it and possibly register the user. + + Args: + remote_user_id: The username from the CAS response. + display_name: The display name from the CAS response. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. + + Returns: + The user ID associated with this response. + """ + + localpart = map_username_to_mxid_localpart(remote_user_id) + user_id = UserID(localpart, self._hostname).to_string() + registered_user_id = await self._auth_handler.check_user_exists(user_id) + + # If the user does not exist, register it. + if not registered_user_id: + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=display_name, + user_agent_ips=[(user_agent, ip_address)], + ) + + return registered_user_id diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 4bfd8d5617..34de9109ea 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -925,7 +925,7 @@ class OidcHandler(BaseHandler): registered_user_id = await self._registration_handler.register_user( localpart=localpart, default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), + user_agent_ips=[(user_agent, ip_address)], ) await self.store.record_user_external_id( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 252f700786..0d85fd0868 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,10 +15,12 @@ """Contains functions for registering clients.""" import logging +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError +from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet @@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class RegistrationHandler(BaseHandler): - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.auth = hs.get_auth() @@ -71,7 +71,10 @@ class RegistrationHandler(BaseHandler): self.session_lifetime = hs.config.session_lifetime async def check_username( - self, localpart, guest_access_token=None, assigned_user_id=None + self, + localpart: str, + guest_access_token: Optional[str] = None, + assigned_user_id: Optional[str] = None, ): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( @@ -140,39 +143,45 @@ class RegistrationHandler(BaseHandler): async def register_user( self, - localpart=None, - password_hash=None, - guest_access_token=None, - make_guest=False, - admin=False, - threepid=None, - user_type=None, - default_display_name=None, - address=None, - bind_emails=[], - by_admin=False, - user_agent_ips=None, - ): + localpart: Optional[str] = None, + password_hash: Optional[str] = None, + guest_access_token: Optional[str] = None, + make_guest: bool = False, + admin: bool = False, + threepid: Optional[dict] = None, + user_type: Optional[str] = None, + default_display_name: Optional[str] = None, + address: Optional[str] = None, + bind_emails: List[str] = [], + by_admin: bool = False, + user_agent_ips: Optional[List[Tuple[str, str]]] = None, + ) -> str: """Registers a new client on the server. Args: localpart: The local part of the user ID to register. If None, one will be generated. - password_hash (str|None): The hashed password to assign to this user so they can + password_hash: The hashed password to assign to this user so they can login again. This can be None which means they cannot login again via a password (e.g. the user is an application service user). - user_type (str|None): type of user. One of the values from + guest_access_token: The access token used when this was a guest + account. + make_guest: True if the the new user should be guest, + false to add a regular user account. + admin: True if the user should be registered as a server admin. + threepid: The threepid used for registering, if any. + user_type: type of user. One of the values from api.constants.UserTypes, or None for a normal user. - default_display_name (unicode|None): if set, the new user's displayname + default_display_name: if set, the new user's displayname will be set to this. Defaults to 'localpart'. - address (str|None): the IP address used to perform the registration. - bind_emails (List[str]): list of emails to bind to this account. - by_admin (bool): True if this registration is being made via the + address: the IP address used to perform the registration. + bind_emails: list of emails to bind to this account. + by_admin: True if this registration is being made via the admin api, otherwise False. - user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + user_agent_ips: Tuples of IP addresses and user-agents used during the registration process. Returns: - str: user_id + The registere user_id. Raises: SynapseError if there was a problem registering. """ @@ -236,8 +245,10 @@ class RegistrationHandler(BaseHandler): else: # autogen a sequential user ID fail_count = 0 - user = None - while not user: + # If a default display name is not given, generate one. + generate_display_name = default_display_name is None + # This breaks on successful registration *or* errors after 10 failures. + while True: # Fail after being unable to find a suitable ID a few times if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") @@ -246,7 +257,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) - if default_display_name is None: + if generate_display_name: default_display_name = localpart try: await self.register_with_store( @@ -262,8 +273,6 @@ class RegistrationHandler(BaseHandler): break except SynapseError: # if user id is taken, just generate another - user = None - user_id = None fail_count += 1 if not self.hs.config.user_consent_at_registration: @@ -295,7 +304,7 @@ class RegistrationHandler(BaseHandler): return user_id - async def _create_and_join_rooms(self, user_id: str): + async def _create_and_join_rooms(self, user_id: str) -> None: """ Create the auto-join rooms and join or invite the user to them. @@ -379,7 +388,7 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - async def _join_rooms(self, user_id: str): + async def _join_rooms(self, user_id: str) -> None: """ Join or invite the user to the auto-join rooms. @@ -425,6 +434,9 @@ class RegistrationHandler(BaseHandler): # Send the invite, if necessary. if requires_invite: + # If an invite is required, there must be a auto-join user ID. + assert self.hs.config.registration.auto_join_user_id + await room_member_handler.update_membership( requester=create_requester( self.hs.config.registration.auto_join_user_id, @@ -456,7 +468,7 @@ class RegistrationHandler(BaseHandler): except Exception as e: logger.error("Failed to join new user to %r: %r", r, e) - async def _auto_join_rooms(self, user_id: str): + async def _auto_join_rooms(self, user_id: str) -> None: """Automatically joins users to auto join rooms - creating the room in the first place if the user is the first to be created. @@ -479,16 +491,16 @@ class RegistrationHandler(BaseHandler): else: await self._join_rooms(user_id) - async def post_consent_actions(self, user_id): + async def post_consent_actions(self, user_id: str) -> None: """A series of registration actions that can only be carried out once consent has been granted Args: - user_id (str): The user to join + user_id: The user to join """ await self._auto_join_rooms(user_id) - async def appservice_register(self, user_localpart, as_token): + async def appservice_register(self, user_localpart: str, as_token: str) -> str: user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -513,7 +525,9 @@ class RegistrationHandler(BaseHandler): ) return user_id - def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): + def check_user_id_not_appservice_exclusive( + self, user_id: str, allowed_appservice: Optional[ApplicationService] = None + ) -> None: # don't allow people to register the server notices mxid if self._server_notices_mxid is not None: if user_id == self._server_notices_mxid: @@ -537,12 +551,12 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address): + def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address Args: - address (str|None): the IP address used to perform the registration. If this is + address: the IP address used to perform the registration. If this is None, no ratelimiting will be performed. Raises: @@ -553,42 +567,39 @@ class RegistrationHandler(BaseHandler): self.ratelimiter.ratelimit(address) - def register_with_store( + async def register_with_store( self, - user_id, - password_hash=None, - was_guest=False, - make_guest=False, - appservice_id=None, - create_profile_with_displayname=None, - admin=False, - user_type=None, - address=None, - shadow_banned=False, - ): + user_id: str, + password_hash: Optional[str] = None, + was_guest: bool = False, + make_guest: bool = False, + appservice_id: Optional[str] = None, + create_profile_with_displayname: Optional[str] = None, + admin: bool = False, + user_type: Optional[str] = None, + address: Optional[str] = None, + shadow_banned: bool = False, + ) -> None: """Register user in the datastore. Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being + user_id: The desired user ID to register. + password_hash: Optional. The password hash for this user. + was_guest: Optional. Whether this is a guest account being upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, + make_guest: True if the the new user should be guest, false to add a regular user account. - appservice_id (str|None): The ID of the appservice registering the user. - create_profile_with_displayname (unicode|None): Optionally create a + appservice_id: The ID of the appservice registering the user. + create_profile_with_displayname: Optionally create a profile for the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from + admin: is an admin user? + user_type: type of user. One of the values from api.constants.UserTypes, or None for a normal user. - address (str|None): the IP address used to perform the registration. - shadow_banned (bool): Whether to shadow-ban the user - - Returns: - Awaitable + address: the IP address used to perform the registration. + shadow_banned: Whether to shadow-ban the user """ if self.hs.config.worker_app: - return self._register_client( + await self._register_client( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -601,7 +612,7 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) else: - return self.store.register_user( + await self.store.register_user( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -614,22 +625,24 @@ class RegistrationHandler(BaseHandler): ) async def register_device( - self, user_id, device_id, initial_display_name, is_guest=False - ): + self, + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + is_guest: bool = False, + ) -> Tuple[str, str]: """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. Args: - user_id (str): full canonical @user:id - device_id (str|None): The device ID to check, or None to generate - a new one. - initial_display_name (str|None): An optional display name for the - device. - is_guest (bool): Whether this is a guest account + user_id: full canonical @user:id + device_id: The device ID to check, or None to generate a new one. + initial_display_name: An optional display name for the device. + is_guest: Whether this is a guest account Returns: - tuple[str, str]: Tuple of device ID and access token + Tuple of device ID and access token """ if self.hs.config.worker_app: @@ -649,7 +662,7 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime - device_id = await self.device_handler.check_device_registered( + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) if is_guest: @@ -659,20 +672,21 @@ class RegistrationHandler(BaseHandler): ) else: access_token = await self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms + user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms ) - return (device_id, access_token) + return (registered_device_id, access_token) - async def post_registration_actions(self, user_id, auth_result, access_token): + async def post_registration_actions( + self, user_id: str, auth_result: dict, access_token: Optional[str] + ) -> None: """A user has completed registration Args: - user_id (str): The user ID that consented - auth_result (dict): The authenticated credentials of the newly - registered user. - access_token (str|None): The access token of the newly logged in - device, or None if `inhibit_login` enabled. + user_id: The user ID that consented + auth_result: The authenticated credentials of the newly registered user. + access_token: The access token of the newly logged in device, or + None if `inhibit_login` enabled. """ if self.hs.config.worker_app: await self._post_registration_client( @@ -698,19 +712,20 @@ class RegistrationHandler(BaseHandler): if auth_result and LoginType.TERMS in auth_result: await self._on_user_consented(user_id, self.hs.config.user_consent_version) - async def _on_user_consented(self, user_id, consent_version): + async def _on_user_consented(self, user_id: str, consent_version: str) -> None: """A user consented to the terms on registration Args: - user_id (str): The user ID that consented. - consent_version (str): version of the policy the user has - consented to. + user_id: The user ID that consented. + consent_version: version of the policy the user has consented to. """ logger.info("%s has consented to the privacy policy", user_id) await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - async def _register_email_threepid(self, user_id, threepid, token): + async def _register_email_threepid( + self, user_id: str, threepid: dict, token: Optional[str] + ) -> None: """Add an email address as a 3pid identifier Also adds an email pusher for the email address, if configured in the @@ -719,10 +734,9 @@ class RegistrationHandler(BaseHandler): Must be called on master. Args: - user_id (str): id of user - threepid (object): m.login.email.identity auth response - token (str|None): access_token for the user, or None if not logged - in. + user_id: id of user + threepid: m.login.email.identity auth response + token: access_token for the user, or None if not logged in. """ reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): @@ -748,6 +762,8 @@ class RegistrationHandler(BaseHandler): # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. user_tuple = await self.store.get_user_by_access_token(token) + # The token better still exist. + assert user_tuple token_id = user_tuple.token_id await self.pusher_pool.add_pusher( @@ -762,14 +778,14 @@ class RegistrationHandler(BaseHandler): data={}, ) - async def _register_msisdn_threepid(self, user_id, threepid): + async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None: """Add a phone number as a 3pid identifier Must be called on master. Args: - user_id (str): id of user - threepid (object): m.login.msisdn auth response + user_id: id of user + threepid: m.login.msisdn auth response """ try: assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index f4e8cbeac8..37ab42f050 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -39,7 +39,7 @@ from synapse.util.async_helpers import Linearizer from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: - import synapse.server + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class Saml2SessionData: class SamlHandler(BaseHandler): - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_idp_entityid = hs.config.saml2_idp_entityid @@ -330,7 +330,7 @@ class SamlHandler(BaseHandler): localpart=localpart, default_display_name=displayname, bind_emails=emails, - user_agent_ips=(user_agent, ip_address), + user_agent_ips=[(user_agent, ip_address)], ) await self.store.record_user_external_id( -- cgit 1.5.1 From 4fd222ad704767e08c41a60690c4b499ed788b63 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 25 Nov 2020 10:04:22 -0500 Subject: Support trying multiple localparts for OpenID Connect. (#8801) Abstracts the SAML and OpenID Connect code which attempts to regenerate the localpart of a matrix ID if it is already in use. --- changelog.d/8801.feature | 1 + docs/sso_mapping_providers.md | 11 ++- synapse/handlers/oidc_handler.py | 120 +++++++++++++----------------- synapse/handlers/saml_handler.py | 91 +++++++---------------- synapse/handlers/sso.py | 155 ++++++++++++++++++++++++++++++++++++++- tests/handlers/test_oidc.py | 88 +++++++++++++++++++++- 6 files changed, 330 insertions(+), 136 deletions(-) create mode 100644 changelog.d/8801.feature (limited to 'synapse/handlers/saml_handler.py') diff --git a/changelog.d/8801.feature b/changelog.d/8801.feature new file mode 100644 index 0000000000..77f7fe4e5d --- /dev/null +++ b/changelog.d/8801.feature @@ -0,0 +1 @@ +Add support for re-trying generation of a localpart for OpenID Connect mapping providers. diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index 707dd73978..dee53b5d40 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -63,13 +63,22 @@ A custom mapping provider must specify the following methods: information from. - This method must return a string, which is the unique identifier for the user. Commonly the ``sub`` claim of the response. -* `map_user_attributes(self, userinfo, token)` +* `map_user_attributes(self, userinfo, token, failures)` - This method must be async. - Arguments: - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user information from. - `token` - A dictionary which includes information necessary to make further requests to the OpenID provider. + - `failures` - An `int` that represents the amount of times the returned + mxid localpart mapping has failed. This should be used + to create a deduplicated mxid localpart which should be + returned instead. For example, if this method returns + `john.doe` as the value of `localpart` in the returned + dict, and that is already taken on the homeserver, this + method will be called again with the same parameters but + with failures=1. The method should then return a different + `localpart` value, such as `john.doe1`. - Returns a dictionary with two keys: - localpart: A required string, used to generate the Matrix ID. - displayname: An optional string, the display name for the user. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 34de9109ea..78c4e94a9d 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode @@ -35,15 +36,10 @@ from twisted.web.client import readBody from synapse.config import ConfigError from synapse.handlers._base import BaseHandler -from synapse.handlers.sso import MappingException +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.types import ( - JsonDict, - UserID, - contains_invalid_mxid_characters, - map_username_to_mxid_localpart, -) +from synapse.types import JsonDict, map_username_to_mxid_localpart from synapse.util import json_decoder if TYPE_CHECKING: @@ -869,73 +865,51 @@ class OidcHandler(BaseHandler): # to be strings. remote_user_id = str(remote_user_id) - # first of all, check if we already have a mapping for this user - previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id( - self._auth_provider_id, remote_user_id, + # Older mapping providers don't accept the `failures` argument, so we + # try and detect support. + mapper_signature = inspect.signature( + self._user_mapping_provider.map_user_attributes ) - if previously_registered_user_id: - return previously_registered_user_id + supports_failures = "failures" in mapper_signature.parameters - # Otherwise, generate a new user. - try: - attributes = await self._user_mapping_provider.map_user_attributes( - userinfo, token - ) - except Exception as e: - raise MappingException( - "Could not extract user attributes from OIDC response: " + str(e) - ) + async def oidc_response_to_user_attributes(failures: int) -> UserAttributes: + """ + Call the mapping provider to map the OIDC userinfo and token to user attributes. - logger.debug( - "Retrieved user attributes from user mapping provider: %r", attributes - ) + This is backwards compatibility for abstraction for the SSO handler. + """ + if supports_failures: + attributes = await self._user_mapping_provider.map_user_attributes( + userinfo, token, failures + ) + else: + # If the mapping provider does not support processing failures, + # do not continually generate the same Matrix ID since it will + # continue to already be in use. Note that the error raised is + # arbitrary and will get turned into a MappingException. + if failures: + raise RuntimeError( + "Mapping provider does not support de-duplicating Matrix IDs" + ) - localpart = attributes["localpart"] - if not localpart: - raise MappingException( - "Error parsing OIDC response: OIDC mapping provider plugin " - "did not return a localpart value" - ) + attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore + userinfo, token + ) - user_id = UserID(localpart, self.server_name).to_string() - users = await self.store.get_users_by_id_case_insensitive(user_id) - if users: - if self._allow_existing_users: - if len(users) == 1: - registered_user_id = next(iter(users)) - elif user_id in users: - registered_user_id = user_id - else: - raise MappingException( - "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( - user_id, list(users.keys()) - ) - ) - else: - # This mxid is taken - raise MappingException("mxid '{}' is already taken".format(user_id)) - else: - # Since the localpart is provided via a potentially untrusted module, - # ensure the MXID is valid before registering. - if contains_invalid_mxid_characters(localpart): - raise MappingException("localpart is invalid: %s" % (localpart,)) - - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=[(user_agent, ip_address)], - ) + return UserAttributes(**attributes) - await self.store.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id, + return await self._sso_handler.get_mxid_from_sso( + self._auth_provider_id, + remote_user_id, + user_agent, + ip_address, + oidc_response_to_user_attributes, + self._allow_existing_users, ) - return registered_user_id -UserAttribute = TypedDict( - "UserAttribute", {"localpart": str, "display_name": Optional[str]} +UserAttributeDict = TypedDict( + "UserAttributeDict", {"localpart": str, "display_name": Optional[str]} ) C = TypeVar("C") @@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]): raise NotImplementedError() async def map_user_attributes( - self, userinfo: UserInfo, token: Token - ) -> UserAttribute: + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: """Map a `UserInfo` object into user attributes. Args: userinfo: An object representing the user given by the OIDC provider token: A dict with the tokens returned by the provider + failures: How many times a call to this function with this + UserInfo has resulted in a failure. Returns: A dict containing the ``localpart`` and (optionally) the ``display_name`` @@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): return userinfo[self._config.subject_claim] async def map_user_attributes( - self, userinfo: UserInfo, token: Token - ) -> UserAttribute: + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: localpart = self._config.localpart_template.render(user=userinfo).strip() # Ensure only valid characters are included in the MXID. localpart = map_username_to_mxid_localpart(localpart) + # Append suffix integer if last call to this function failed to produce + # a usable mxid. + localpart += str(failures) if failures else "" + display_name = None # type: Optional[str] if self._config.display_name_template is not None: display_name = self._config.display_name_template.render( @@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if display_name == "": display_name = None - return UserAttribute(localpart=localpart, display_name=display_name) + return UserAttributeDict(localpart=localpart, display_name=display_name) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: extras = {} # type: Dict[str, str] diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 37ab42f050..34db10ffe4 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -25,13 +25,12 @@ from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.saml2_config import SamlAttributeRequirement from synapse.handlers._base import BaseHandler -from synapse.handlers.sso import MappingException +from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi from synapse.types import ( UserID, - contains_invalid_mxid_characters, map_username_to_mxid_localpart, mxid_localpart_allowed_characters, ) @@ -250,14 +249,26 @@ class SamlHandler(BaseHandler): "Failed to extract remote user id from SAML response" ) - with (await self._mapping_lock.queue(self._auth_provider_id)): - # first of all, check if we already have a mapping for this user - previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id( - self._auth_provider_id, remote_user_id, + async def saml_response_to_remapped_user_attributes( + failures: int, + ) -> UserAttributes: + """ + Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form. + + This is backwards compatibility for abstraction for the SSO handler. + """ + # Call the mapping provider. + result = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, failures, client_redirect_url + ) + # Remap some of the results. + return UserAttributes( + localpart=result.get("mxid_localpart"), + display_name=result.get("displayname"), + emails=result.get("emails"), ) - if previously_registered_user_id: - return previously_registered_user_id + with (await self._mapping_lock.queue(self._auth_provider_id)): # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid if ( @@ -284,59 +295,13 @@ class SamlHandler(BaseHandler): ) return registered_user_id - # Map saml response to user attributes using the configured mapping provider - for i in range(1000): - attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( - saml2_auth, i, client_redirect_url=client_redirect_url, - ) - - logger.debug( - "Retrieved SAML attributes from user mapping provider: %s " - "(attempt %d)", - attribute_dict, - i, - ) - - localpart = attribute_dict.get("mxid_localpart") - if not localpart: - raise MappingException( - "Error parsing SAML2 response: SAML mapping provider plugin " - "did not return a mxid_localpart value" - ) - - displayname = attribute_dict.get("displayname") - emails = attribute_dict.get("emails", []) - - # Check if this mxid already exists - if not await self.store.get_users_by_id_case_insensitive( - UserID(localpart, self.server_name).to_string() - ): - # This mxid is free - break - else: - # Unable to generate a username in 1000 iterations - # Break and return error to the user - raise MappingException( - "Unable to generate a Matrix ID from the SAML response" - ) - - # Since the localpart is provided via a potentially untrusted module, - # ensure the MXID is valid before registering. - if contains_invalid_mxid_characters(localpart): - raise MappingException("localpart is invalid: %s" % (localpart,)) - - logger.debug("Mapped SAML user to local part %s", localpart) - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=displayname, - bind_emails=emails, - user_agent_ips=[(user_agent, ip_address)], - ) - - await self.store.record_user_external_id( - self._auth_provider_id, remote_user_id, registered_user_id + return await self._sso_handler.get_mxid_from_sso( + self._auth_provider_id, + remote_user_id, + user_agent, + ip_address, + saml_response_to_remapped_user_attributes, ) - return registered_user_id def expire_sessions(self): expire_before = self.clock.time_msec() - self._saml2_session_lifetime @@ -451,11 +416,11 @@ class DefaultSamlMappingProvider: ) # Use the configured mapper for this mxid_source - base_mxid_localpart = self._mxid_mapper(mxid_source) + localpart = self._mxid_mapper(mxid_source) # Append suffix integer if last call to this function failed to produce - # a usable mxid - localpart = base_mxid_localpart + (str(failures) if failures else "") + # a usable mxid. + localpart += str(failures) if failures else "" # Retrieve the display name from the saml response # If displayname is None, the mxid_localpart will be used instead diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index cf7cb7754a..d963082210 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -13,10 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional + +import attr from synapse.handlers._base import BaseHandler from synapse.http.server import respond_with_html +from synapse.types import UserID, contains_invalid_mxid_characters if TYPE_CHECKING: from synapse.server import HomeServer @@ -29,9 +32,20 @@ class MappingException(Exception): """ +@attr.s +class UserAttributes: + localpart = attr.ib(type=str) + display_name = attr.ib(type=Optional[str], default=None) + emails = attr.ib(type=List[str], default=attr.Factory(list)) + + class SsoHandler(BaseHandler): + # The number of attempts to ask the mapping provider for when generating an MXID. + _MAP_USERNAME_RETRIES = 1000 + def __init__(self, hs: "HomeServer"): super().__init__(hs) + self._registration_handler = hs.get_registration_handler() self._error_template = hs.config.sso_error_template def render_error( @@ -94,3 +108,142 @@ class SsoHandler(BaseHandler): # No match. return None + + async def get_mxid_from_sso( + self, + auth_provider_id: str, + remote_user_id: str, + user_agent: str, + ip_address: str, + sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + allow_existing_users: bool = False, + ) -> str: + """ + Given an SSO ID, retrieve the user ID for it and possibly register the user. + + This first checks if the SSO ID has previously been linked to a matrix ID, + if it has that matrix ID is returned regardless of the current mapping + logic. + + The mapping function is called (potentially multiple times) to generate + a localpart for the user. + + If an unused localpart is generated, the user is registered from the + given user-agent and IP address and the SSO ID is linked to this matrix + ID for subsequent calls. + + If allow_existing_users is true the mapping function is only called once + and results in: + + 1. The use of a previously registered matrix ID. In this case, the + SSO ID is linked to the matrix ID. (Note it is possible that + other SSO IDs are linked to the same matrix ID.) + 2. An unused localpart, in which case the user is registered (as + discussed above). + 3. An error if the generated localpart matches multiple pre-existing + matrix IDs. Generally this should not happen. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + remote_user_id: The unique identifier from the SSO provider. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. + sso_to_matrix_id_mapper: A callable to generate the user attributes. + The only parameter is an integer which represents the amount of + times the returned mxid localpart mapping has failed. + allow_existing_users: True if the localpart returned from the + mapping provider can be linked to an existing matrix ID. + + Returns: + The user ID associated with the SSO response. + + Raises: + MappingException if there was a problem mapping the response to a user. + RedirectException: some mapping providers may raise this if they need + to redirect to an interstitial page. + + """ + # first of all, check if we already have a mapping for this user + previously_registered_user_id = await self.get_sso_user_by_remote_user_id( + auth_provider_id, remote_user_id, + ) + if previously_registered_user_id: + return previously_registered_user_id + + # Otherwise, generate a new user. + for i in range(self._MAP_USERNAME_RETRIES): + try: + attributes = await sso_to_matrix_id_mapper(i) + except Exception as e: + raise MappingException( + "Could not extract user attributes from SSO response: " + str(e) + ) + + logger.debug( + "Retrieved user attributes from user mapping provider: %r (attempt %d)", + attributes, + i, + ) + + if not attributes.localpart: + raise MappingException( + "Error parsing SSO response: SSO mapping provider plugin " + "did not return a localpart value" + ) + + # Check if this mxid already exists + user_id = UserID(attributes.localpart, self.server_name).to_string() + users = await self.store.get_users_by_id_case_insensitive(user_id) + # Note, if allow_existing_users is true then the loop is guaranteed + # to end on the first iteration: either by matching an existing user, + # raising an error, or registering a new user. See the docstring for + # more in-depth an explanation. + if users and allow_existing_users: + # If an existing matrix ID is returned, then use it. + if len(users) == 1: + previously_registered_user_id = next(iter(users)) + elif user_id in users: + previously_registered_user_id = user_id + else: + # Do not attempt to continue generating Matrix IDs. + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, users + ) + ) + + # Future logins should also match this user ID. + await self.store.record_user_external_id( + auth_provider_id, remote_user_id, previously_registered_user_id + ) + + return previously_registered_user_id + + elif not users: + # This mxid is free + break + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise MappingException( + "Unable to generate a Matrix ID from the SSO response" + ) + + # Since the localpart is provided via a potentially untrusted module, + # ensure the MXID is valid before registering. + if contains_invalid_mxid_characters(attributes.localpart): + raise MappingException("localpart is invalid: %s" % (attributes.localpart,)) + + logger.debug("Mapped SSO user to local part %s", attributes.localpart) + registered_user_id = await self._registration_handler.register_user( + localpart=attributes.localpart, + default_display_name=attributes.display_name, + bind_emails=attributes.emails, + user_agent_ips=[(user_agent, ip_address)], + ) + + await self.store.record_user_external_id( + auth_provider_id, remote_user_id, registered_user_id + ) + return registered_user_id diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index b4fa02acc4..e880d32be6 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -89,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider): return {"phone": userinfo["phone"]} +class TestMappingProviderFailures(TestMappingProvider): + async def map_user_attributes(self, userinfo, token, failures): + return { + "localpart": userinfo["username"] + (str(failures) if failures else ""), + "display_name": None, + } + + def simple_async_mock(return_value=None, raises=None): # AsyncMock is not available in python3.5, this mimics part of its behaviour async def cb(*args, **kwargs): @@ -152,6 +160,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.render_error = Mock(return_value=None) self.handler._sso_handler.render_error = self.render_error + # Reduce the number of attempts when generating MXIDs. + self.handler._sso_handler._MAP_USERNAME_RETRIES = 3 + return hs def metadata_edit(self, values): @@ -693,7 +704,10 @@ class OidcHandlerTestCase(HomeserverTestCase): ), MappingException, ) - self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken") + self.assertEqual( + str(e.value), + "Could not extract user attributes from SSO response: Mapping provider does not support de-duplicating Matrix IDs", + ) @override_config({"oidc_config": {"allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): @@ -703,6 +717,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success( store.register_user(user_id=user.to_string(), password_hash=None) ) + + # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", @@ -715,6 +731,23 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.assertEqual(mxid, "@test_user:test") + # Note that a second SSO user can be mapped to the same Matrix ID. (This + # requires a unique sub, but something that maps to the same matrix ID, + # in this case we'll just use the same username. A more realistic example + # would be subs which are email addresses, and mapping from the localpart + # of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.) + userinfo = { + "sub": "test1", + "username": "test_user", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") self.get_success( @@ -762,6 +795,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "föö", } token = {} + e = self.get_failure( self.handler._map_userinfo_to_user( userinfo, token, "user-agent", "10.10.10.10" @@ -769,3 +803,55 @@ class OidcHandlerTestCase(HomeserverTestCase): MappingException, ) self.assertEqual(str(e.value), "localpart is invalid: föö") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "module": __name__ + ".TestMappingProviderFailures" + } + } + } + ) + def test_map_userinfo_to_user_retries(self): + """The mapping provider can retry generating an MXID if the MXID is already in use.""" + store = self.hs.get_datastore() + self.get_success( + store.register_user(user_id="@test_user:test", password_hash=None) + ) + userinfo = { + "sub": "test", + "username": "test_user", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + # test_user is already taken, so test_user1 gets registered instead. + self.assertEqual(mxid, "@test_user1:test") + + # Register all of the potential users for a particular username. + self.get_success( + store.register_user(user_id="@tester:test", password_hash=None) + ) + for i in range(1, 3): + self.get_success( + store.register_user(user_id="@tester%d:test" % i, password_hash=None) + ) + + # Now attempt to map to a username, this will fail since all potential usernames are taken. + userinfo = { + "sub": "tester", + "username": "tester", + } + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual( + str(e.value), "Unable to generate a Matrix ID from the SSO response" + ) -- cgit 1.5.1