From 01333681bc3db22541b49c194f5121a5415731c6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 15 Dec 2020 20:56:10 +0000 Subject: Preparatory refactoring of the SamlHandlerTestCase (#8938) * move simple_async_mock to test_utils ... so that it can be re-used * Remove references to `SamlHandler._map_saml_response_to_user` from tests This method is going away, so we can no longer use it as a test point. Instead, factor out a higher-level method which takes a SAML object, and verify correct behaviour by mocking out `AuthHandler.complete_sso_login`. * changelog --- tests/handlers/test_saml.py | 132 +++++++++++++++++++++++++++++--------------- 1 file changed, 89 insertions(+), 43 deletions(-) (limited to 'tests/handlers/test_saml.py') diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index d21e5588ca..69927cf6be 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from mock import Mock + import attr from synapse.api.errors import RedirectException -from synapse.handlers.sso import MappingException +from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. @@ -44,6 +48,8 @@ BASE_URL = "https://synapse/" @attr.s class FakeAuthnResponse: ava = attr.ib(type=dict) + assertions = attr.ib(type=list, factory=list) + in_response_to = attr.ib(type=Optional[str], default=None) class TestMappingProvider: @@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase): def test_map_saml_response_to_user(self): """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) - # The redirect_url doesn't matter with the default user mapping provider. - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri" ) - self.assertEqual(mxid, "@test_user:test") @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): @@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase): store.register_user(user_id="@test_user:test", password_hash=None) ) + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # Map a user via SSO. saml_response = FakeAuthnResponse( {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} ) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") # Subsequent calls should map to the same mxid. - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") def test_map_saml_response_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # mock out the error renderer too + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) - redirect_url = "" - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), + ) + sso_handler.render_error.assert_called_once_with( + request, "mapping_error", "localpart is invalid: föö" ) - self.assertEqual(str(e.value), "localpart is invalid: föö") + auth_handler.complete_sso_login.assert_not_called() def test_map_saml_response_to_user_retries(self): """The mapping provider can retry generating an MXID if the MXID is already in use.""" + + # stub out the auth handler and error renderer + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + + # register a user to occupy the first-choice MXID store = self.hs.get_datastore() self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) + + # send the fake SAML response saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) + # test_user is already taken, so test_user1 gets registered instead. - self.assertEqual(mxid, "@test_user1:test") + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", request, "" + ) + auth_handler.complete_sso_login.reset_mock() # Register all of the potential mxids for a particular SAML username. self.get_success( @@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase): # Now attempt to map to a username, this will fail since all potential usernames are taken. saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) - self.assertEqual( - str(e.value), "Unable to generate a Matrix ID from the SSO response" + sso_handler.render_error.assert_called_once_with( + request, + "mapping_error", + "Unable to generate a Matrix ID from the SSO response", ) + auth_handler.complete_sso_login.assert_not_called() @override_config( { @@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase): } ) def test_map_saml_response_redirect(self): + """Test a mapping provider that raises a RedirectException""" + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" + request = _mock_request() e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), + self.handler._handle_authn_response(request, saml_response, ""), RedirectException, ) self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) -- cgit 1.5.1 From e1b8e37f936b115e2164d272333c9b15342e6f88 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 16 Dec 2020 20:01:53 +0000 Subject: Push login completion down into SsoHandler (#8941) This is another part of my work towards fixing #8876. It moves some of the logic currently in the SAML and OIDC handlers - in particular the call to `AuthHandler.complete_sso_login` down into the `SsoHandler`. --- changelog.d/8941.feature | 1 + synapse/handlers/oidc_handler.py | 62 +++++++++++++++++----------------------- synapse/handlers/saml_handler.py | 37 ++++++++---------------- synapse/handlers/sso.py | 58 +++++++++++++++++++++++-------------- tests/handlers/test_saml.py | 8 +++--- 5 files changed, 80 insertions(+), 86 deletions(-) create mode 100644 changelog.d/8941.feature (limited to 'tests/handlers/test_saml.py') diff --git a/changelog.d/8941.feature b/changelog.d/8941.feature new file mode 100644 index 0000000000..d450ef4998 --- /dev/null +++ b/changelog.d/8941.feature @@ -0,0 +1 @@ +Add support for allowing users to pick their own user ID during a single-sign-on login. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index f626117f76..cbd11a1382 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -115,8 +115,6 @@ class OidcHandler(BaseHandler): self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._http_client = hs.get_proxied_http_client() - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() self._server_name = hs.config.server_name # type: str self._macaroon_secret_key = hs.config.macaroon_secret_key @@ -689,33 +687,14 @@ class OidcHandler(BaseHandler): # otherwise, it's a login - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) - # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user( - userinfo, token, user_agent, ip_address + await self._complete_oidc_login( + userinfo, token, request, client_redirect_url ) except MappingException as e: logger.exception("Could not map user") self._sso_handler.render_error(request, "mapping_error", str(e)) - return - - # Mapping providers might not have get_extra_attributes: only call this - # method if it exists. - extra_attributes = None - get_extra_attributes = getattr( - self._user_mapping_provider, "get_extra_attributes", None - ) - if get_extra_attributes: - extra_attributes = await get_extra_attributes(userinfo, token) - - # and finally complete the login - await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url, extra_attributes - ) def _generate_oidc_session_token( self, @@ -838,10 +817,14 @@ class OidcHandler(BaseHandler): now = self.clock.time_msec() return now < expiry - async def _map_userinfo_to_user( - self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str - ) -> str: - """Maps a UserInfo object to a mxid. + async def _complete_oidc_login( + self, + userinfo: UserInfo, + token: Token, + request: SynapseRequest, + client_redirect_url: str, + ) -> None: + """Given a UserInfo response, complete the login flow UserInfo should have a claim that uniquely identifies users. This claim is usually `sub`, but can be configured with `oidc_config.subject_claim`. @@ -853,17 +836,16 @@ class OidcHandler(BaseHandler): If a user already exists with the mxid we've mapped and allow_existing_users is disabled, raise an exception. + Otherwise, render a redirect back to the client_redirect_url with a loginToken. + Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. + request: The request to respond to + client_redirect_url: The redirect URL passed in by the client. Raises: MappingException: if there was an error while mapping some properties - - Returns: - The mxid of the user """ try: remote_user_id = self._remote_id_from_userinfo(userinfo) @@ -931,13 +913,23 @@ class OidcHandler(BaseHandler): return None - return await self._sso_handler.get_mxid_from_sso( + # Mapping providers might not have get_extra_attributes: only call this + # method if it exists. + extra_attributes = None + get_extra_attributes = getattr( + self._user_mapping_provider, "get_extra_attributes", None + ) + if get_extra_attributes: + extra_attributes = await get_extra_attributes(userinfo, token) + + await self._sso_handler.complete_sso_login_request( self._auth_provider_id, remote_user_id, - user_agent, - ip_address, + request, + client_redirect_url, oidc_response_to_user_attributes, grandfather_existing_users, + extra_attributes, ) def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 6001fe3e27..5fa7ab3f8b 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -58,8 +58,6 @@ class SamlHandler(BaseHandler): super().__init__(hs) self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_idp_entityid = hs.config.saml2_idp_entityid - self._auth_handler = hs.get_auth_handler() - self._registration_handler = hs.get_registration_handler() self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._grandfathered_mxid_source_attribute = ( @@ -229,40 +227,29 @@ class SamlHandler(BaseHandler): ) return - # Pull out the user-agent and IP from the request. - user_agent = request.get_user_agent("") - ip_address = self.hs.get_ip_from_request(request) - # Call the mapper to register/login the user try: - user_id = await self._map_saml_response_to_user( - saml2_auth, relay_state, user_agent, ip_address - ) + await self._complete_saml_login(saml2_auth, request, relay_state) except MappingException as e: logger.exception("Could not map user") self._sso_handler.render_error(request, "mapping_error", str(e)) - return - await self._auth_handler.complete_sso_login(user_id, request, relay_state) - - async def _map_saml_response_to_user( + async def _complete_saml_login( self, saml2_auth: saml2.response.AuthnResponse, + request: SynapseRequest, client_redirect_url: str, - user_agent: str, - ip_address: str, - ) -> str: + ) -> None: """ - Given a SAML response, retrieve the user ID for it and possibly register the user. + Given a SAML response, complete the login flow + + Retrieves the remote user ID, registers the user if necessary, and serves + a redirect back to the client with a login-token. Args: saml2_auth: The parsed SAML2 response. + request: The request to respond to client_redirect_url: The redirect URL passed in by the client. - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. - - Returns: - The user ID associated with this response. Raises: MappingException if there was a problem mapping the response to a user. @@ -318,11 +305,11 @@ class SamlHandler(BaseHandler): return None - return await self._sso_handler.get_mxid_from_sso( + await self._sso_handler.complete_sso_login_request( self._auth_provider_id, remote_user_id, - user_agent, - ip_address, + request, + client_redirect_url, saml_response_to_remapped_user_attributes, grandfather_existing_users, ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 112a7d5b2c..f054b66a53 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -21,7 +21,8 @@ from twisted.web.http import Request from synapse.api.errors import RedirectException from synapse.http.server import respond_with_html -from synapse.types import UserID, contains_invalid_mxid_characters +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer if TYPE_CHECKING: @@ -119,15 +120,16 @@ class SsoHandler: # No match. return None - async def get_mxid_from_sso( + async def complete_sso_login_request( self, auth_provider_id: str, remote_user_id: str, - user_agent: str, - ip_address: str, + request: SynapseRequest, + client_redirect_url: str, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]], - ) -> str: + extra_login_attributes: Optional[JsonDict] = None, + ) -> None: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -146,12 +148,18 @@ class SsoHandler: given user-agent and IP address and the SSO ID is linked to this matrix ID for subsequent calls. + Finally, we generate a redirect to the supplied redirect uri, with a login token + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". + remote_user_id: The unique identifier from the SSO provider. - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. + + request: The request to respond to + + client_redirect_url: The redirect URL passed in by the client. + sso_to_matrix_id_mapper: A callable to generate the user attributes. The only parameter is an integer which represents the amount of times the returned mxid localpart mapping has failed. @@ -163,12 +171,13 @@ class SsoHandler: to the user. RedirectException to redirect to an additional page (e.g. to prompt the user for more information). + grandfather_existing_users: A callable which can return an previously existing matrix ID. The SSO ID is then linked to the returned matrix ID. - Returns: - The user ID associated with the SSO response. + extra_login_attributes: An optional dictionary of extra + attributes to be provided to the client in the login response. Raises: MappingException if there was a problem mapping the response to a user. @@ -181,28 +190,33 @@ class SsoHandler: # interstitial pages. with await self._mapping_lock.queue(auth_provider_id): # first of all, check if we already have a mapping for this user - previously_registered_user_id = await self.get_sso_user_by_remote_user_id( + user_id = await self.get_sso_user_by_remote_user_id( auth_provider_id, remote_user_id, ) - if previously_registered_user_id: - return previously_registered_user_id # Check for grandfathering of users. - if grandfather_existing_users: - previously_registered_user_id = await grandfather_existing_users() - if previously_registered_user_id: + if not user_id and grandfather_existing_users: + user_id = await grandfather_existing_users() + if user_id: # Future logins should also match this user ID. await self._store.record_user_external_id( - auth_provider_id, remote_user_id, previously_registered_user_id + auth_provider_id, remote_user_id, user_id ) - return previously_registered_user_id # Otherwise, generate a new user. - attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) - user_id = await self._register_mapped_user( - attributes, auth_provider_id, remote_user_id, user_agent, ip_address, - ) - return user_id + if not user_id: + attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) + user_id = await self._register_mapped_user( + attributes, + auth_provider_id, + remote_user_id, + request.get_user_agent(""), + request.getClientIP(), + ) + + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url, extra_login_attributes + ) async def _call_attribute_mapper( self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 69927cf6be..548038214b 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri" + "@test_user:test", request, "redirect_uri", None ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "" + "@test_user:test", request, "", None ) # Subsequent calls should map to the same mxid. @@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "" + "@test_user:test", request, "", None ) def test_map_saml_response_to_invalid_localpart(self): @@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", request, "" + "@test_user1:test", request, "", None ) auth_handler.complete_sso_login.reset_mock() -- cgit 1.5.1