diff options
Diffstat (limited to 'tests/handlers/test_oidc.py')
-rw-r--r-- | tests/handlers/test_oidc.py | 519 |
1 files changed, 290 insertions, 229 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index a308c46da9..ad20400b1d 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,32 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from typing import Optional from urllib.parse import parse_qs, urlparse -from mock import Mock, patch +from mock import ANY, Mock, patch -import attr import pymacaroons -from twisted.python.failure import Failure -from twisted.web._newclient import ResponseDone - -from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider from synapse.handlers.sso import MappingException +from synapse.server import HomeServer from synapse.types import UserID +from tests.test_utils import FakeResponse, simple_async_mock from tests.unittest import HomeserverTestCase, override_config +try: + import authlib # noqa: F401 -@attr.s -class FakeResponse: - code = attr.ib() - body = attr.ib() - phrase = attr.ib() - - def deliverBody(self, protocol): - protocol.dataReceived(self.body) - protocol.connectionLost(Failure(ResponseDone())) + HAS_OIDC = True +except ImportError: + HAS_OIDC = False # These are a few constants that are used as config parameters in the tests. @@ -46,7 +40,7 @@ ISSUER = "https://issuer/" CLIENT_ID = "test-client-id" CLIENT_SECRET = "test-client-secret" BASE_URL = "https://synapse/" -CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] AUTHORIZATION_ENDPOINT = ISSUER + "authorize" @@ -64,17 +58,14 @@ COMMON_CONFIG = { } -# The cookie name and path don't really matter, just that it has to be coherent -# between the callback & redirect handlers. -COOKIE_NAME = b"oidc_session" -COOKIE_PATH = "/_synapse/oidc" - - -class TestMappingProvider(OidcMappingProvider): +class TestMappingProvider: @staticmethod def parse_config(config): return + def __init__(self, config): + pass + def get_remote_user_id(self, userinfo): return userinfo["sub"] @@ -97,16 +88,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -def simple_async_mock(return_value=None, raises=None): - # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args, **kwargs): - if raises: - raise raises - return return_value - - return Mock(side_effect=cb) - - async def get_json(url): # Mock get_json calls to handle jwks & oidc discovery endpoints if url == WELL_KNOWN: @@ -127,6 +108,9 @@ async def get_json(url): class OidcHandlerTestCase(HomeserverTestCase): + if not HAS_OIDC: + skip = "requires OIDC" + def default_config(self): config = super().default_config() config["public_baseurl"] = BASE_URL @@ -155,6 +139,7 @@ class OidcHandlerTestCase(HomeserverTestCase): hs = self.setup_test_homeserver(proxied_http_client=self.http_client) self.handler = hs.get_oidc_handler() + self.provider = self.handler._providers["oidc"] sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) @@ -166,27 +151,29 @@ class OidcHandlerTestCase(HomeserverTestCase): return hs def metadata_edit(self, values): - return patch.dict(self.handler._provider_metadata, values) + return patch.dict(self.provider._provider_metadata, values) def assertRenderedError(self, error, error_description=None): + self.render_error.assert_called_once() args = self.render_error.call_args[0] self.assertEqual(args[1], error) if error_description is not None: self.assertEqual(args[2], error_description) # Reset the render_error mock self.render_error.reset_mock() + return args def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" - self.assertEqual(self.handler._callback_url, CALLBACK_URL) - self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID) - self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) + self.assertEqual(self.provider._callback_url, CALLBACK_URL) + self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) + self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {"discover": True}}) def test_discovery(self): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid - metadata = self.get_success(self.handler.load_metadata()) + metadata = self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_called_once_with(WELL_KNOWN) self.assertEqual(metadata.issuer, ISSUER) @@ -198,47 +185,47 @@ class OidcHandlerTestCase(HomeserverTestCase): # subsequent calls should be cached self.http_client.reset_mock() - self.get_success(self.handler.load_metadata()) + self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) def test_no_discovery(self): """When discovery is disabled, it should not try to load from discovery document.""" - self.get_success(self.handler.load_metadata()) + self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) def test_load_jwks(self): """JWKS loading is done once (then cached) if used.""" - jwks = self.get_success(self.handler.load_jwks()) + jwks = self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) self.assertEqual(jwks, {"keys": []}) # subsequent calls should be cached… self.http_client.reset_mock() - self.get_success(self.handler.load_jwks()) + self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_not_called() # …unless forced self.http_client.reset_mock() - self.get_success(self.handler.load_jwks(force=True)) + self.get_success(self.provider.load_jwks(force=True)) self.http_client.get_json.assert_called_once_with(JWKS_URI) # Throw if the JWKS uri is missing with self.metadata_edit({"jwks_uri": None}): - self.get_failure(self.handler.load_jwks(force=True), RuntimeError) + self.get_failure(self.provider.load_jwks(force=True), RuntimeError) # Return empty key set if JWKS are not used - self.handler._scopes = [] # not asking the openid scope + self.provider._scopes = [] # not asking the openid scope self.http_client.get_json.reset_mock() - jwks = self.get_success(self.handler.load_jwks(force=True)) + jwks = self.get_success(self.provider.load_jwks(force=True)) self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) @override_config({"oidc_config": COMMON_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" - h = self.handler + h = self.provider # Default test config does not throw h._validate_metadata() @@ -317,13 +304,13 @@ class OidcHandlerTestCase(HomeserverTestCase): """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw - self.handler._validate_metadata() + self.provider._validate_metadata() def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["addCookie"]) url = self.get_success( - self.handler.handle_redirect_request(req, b"http://client/redirect") + self.provider.handle_redirect_request(req, b"http://client/redirect") ) url = urlparse(url) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) @@ -347,14 +334,21 @@ class OidcHandlerTestCase(HomeserverTestCase): # For some reason, call.args does not work with python3.5 args = calls[0][0] kwargs = calls[0][1] - self.assertEqual(args[0], COOKIE_NAME) - self.assertEqual(kwargs["path"], COOKIE_PATH) + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + self.assertEqual(args[0], b"oidc_session") + self.assertEqual(kwargs["path"], "/_synapse/client/oidc") cookie = args[1] macaroon = pymacaroons.Macaroon.deserialize(cookie) - state = self.handler._get_value_from_macaroon(macaroon, "state") - nonce = self.handler._get_value_from_macaroon(macaroon, "nonce") - redirect = self.handler._get_value_from_macaroon( + state = self.handler._token_generator._get_value_from_macaroon( + macaroon, "state" + ) + nonce = self.handler._token_generator._get_value_from_macaroon( + macaroon, "nonce" + ) + redirect = self.handler._token_generator._get_value_from_macaroon( macaroon, "client_redirect_url" ) @@ -384,31 +378,29 @@ class OidcHandlerTestCase(HomeserverTestCase): - when the userinfo fetching fails - when the code exchange fails """ + + # ensure that we are correctly testing the fallback when "get_extra_attributes" + # is not implemented. + mapping_provider = self.provider._user_mapping_provider + with self.assertRaises(AttributeError): + _ = mapping_provider.get_extra_attributes + token = { "type": "bearer", "id_token": "id_token", "access_token": "access_token", } + username = "bar" userinfo = { "sub": "foo", - "preferred_username": "bar", + "username": username, } - user_id = "@foo:domain.org" - self.handler._exchange_code = simple_async_mock(return_value=token) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) - self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock( - spec=[ - "args", - "getCookie", - "addCookie", - "requestHeaders", - "getClientIP", - "get_user_agent", - ] - ) + expected_user_id = "@%s:%s" % (username, self.hs.hostname) + self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() code = "code" state = "state" @@ -416,74 +408,61 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url = "http://client/redirect" user_agent = "Browser" ip_address = "10.0.0.1" - request.getCookie.return_value = self.handler._generate_oidc_session_token( - state=state, - nonce=nonce, - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, + session = self._generate_oidc_session_token(state, nonce, client_redirect_url) + request = _build_callback_request( + code, state, session, user_agent=user_agent, ip_address=ip_address ) - request.args = {} - request.args[b"code"] = [code.encode("utf-8")] - request.args[b"state"] = [state.encode("utf-8")] - - request.getClientIP.return_value = ip_address - request.get_user_agent.return_value = user_agent - self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {}, + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, request, client_redirect_url, None, new_user=True ) - self.handler._exchange_code.assert_called_once_with(code) - self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with( - userinfo, token, user_agent, ip_address - ) - self.handler._fetch_userinfo.assert_not_called() + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) + self.provider._fetch_userinfo.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors - self.handler._map_userinfo_to_user = simple_async_mock( - raises=MappingException() - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("mapping_error") - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + with patch.object( + self.provider, + "_remote_id_from_userinfo", + new=Mock(side_effect=MappingException()), + ): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mapping_error") # Handle ID token errors - self.handler._parse_id_token = simple_async_mock(raises=Exception()) + self.provider._parse_id_token = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - self.handler._auth_handler.complete_sso_login.reset_mock() - self.handler._exchange_code.reset_mock() - self.handler._parse_id_token.reset_mock() - self.handler._map_userinfo_to_user.reset_mock() - self.handler._fetch_userinfo.reset_mock() + auth_handler.complete_sso_login.reset_mock() + self.provider._exchange_code.reset_mock() + self.provider._parse_id_token.reset_mock() + self.provider._fetch_userinfo.reset_mock() # With userinfo fetching - self.handler._scopes = [] # do not ask the "openid" scope + self.provider._scopes = [] # do not ask the "openid" scope self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {}, - ) - self.handler._exchange_code.assert_called_once_with(code) - self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with( - userinfo, token, user_agent, ip_address + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, request, client_redirect_url, None, new_user=False ) - self.handler._fetch_userinfo.assert_called_once_with(token) + self.provider._exchange_code.assert_called_once_with(code) + self.provider._parse_id_token.assert_not_called() + self.provider._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() # Handle userinfo fetching error - self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) + self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure - self.handler._exchange_code = simple_async_mock( + from synapse.handlers.oidc_handler import OidcError + + self.provider._exchange_code = simple_async_mock( raises=OidcError("invalid_request") ) self.get_success(self.handler.handle_oidc_callback(request)) @@ -513,11 +492,8 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("invalid_session") # Mismatching session - session = self.handler._generate_oidc_session_token( - state="state", - nonce="nonce", - client_redirect_url="http://client/redirect", - ui_auth_session_id=None, + session = self._generate_oidc_session_token( + state="state", nonce="nonce", client_redirect_url="http://client/redirect", ) request.args = {} request.args[b"state"] = [b"mismatching state"] @@ -541,7 +517,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) ) code = "code" - ret = self.get_success(self.handler._exchange_code(code)) + ret = self.get_success(self.provider._exchange_code(code)) kwargs = self.http_client.request.call_args[1] self.assertEqual(ret, token) @@ -563,7 +539,9 @@ class OidcHandlerTestCase(HomeserverTestCase): body=b'{"error": "foo", "error_description": "bar"}', ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + from synapse.handlers.oidc_handler import OidcError + + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "foo") self.assertEqual(exc.value.error_description, "bar") @@ -573,7 +551,7 @@ class OidcHandlerTestCase(HomeserverTestCase): code=500, phrase=b"Internal Server Error", body=b"Not JSON", ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body @@ -585,14 +563,14 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field @@ -601,7 +579,7 @@ class OidcHandlerTestCase(HomeserverTestCase): code=200, phrase=b"OK", body=b'{"error": "some_error"}', ) ) - exc = self.get_failure(self.handler._exchange_code(code), OidcError) + exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "some_error") @override_config( @@ -624,72 +602,56 @@ class OidcHandlerTestCase(HomeserverTestCase): } userinfo = { "sub": "foo", + "username": "foo", "phone": "1234567", } - user_id = "@foo:domain.org" - self.handler._exchange_code = simple_async_mock(return_value=token) - self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) - self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock( - spec=[ - "args", - "getCookie", - "addCookie", - "requestHeaders", - "getClientIP", - "get_user_agent", - ] - ) + self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() state = "state" client_redirect_url = "http://client/redirect" - request.getCookie.return_value = self.handler._generate_oidc_session_token( - state=state, - nonce="nonce", - client_redirect_url=client_redirect_url, - ui_auth_session_id=None, + session = self._generate_oidc_session_token( + state=state, nonce="nonce", client_redirect_url=client_redirect_url, ) - - request.args = {} - request.args[b"code"] = [b"code"] - request.args[b"state"] = [state.encode("utf-8")] - - request.getClientIP.return_value = "10.0.0.1" - request.get_user_agent.return_value = "Browser" + request = _build_callback_request("code", state, session) self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {"phone": "1234567"}, + auth_handler.complete_sso_login.assert_called_once_with( + "@foo:test", + request, + client_redirect_url, + {"phone": "1234567"}, + new_user=True, ) def test_map_userinfo_to_user(self): """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + userinfo = { "sub": "test_user", "username": "test_user", } - # The token doesn't matter with the default user mapping provider. - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", ANY, ANY, None, new_user=True ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user_2:test", ANY, ANY, None, new_user=True ) - self.assertEqual(mxid, "@test_user_2:test") + auth_handler.complete_sso_login.reset_mock() # Test if the mxid is already taken store = self.hs.get_datastore() @@ -698,14 +660,11 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, - ) - self.assertEqual( - str(e.value), "Mapping provider does not support de-duplicating Matrix IDs", + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + self.assertRenderedError( + "mapping_error", + "Mapping provider does not support de-duplicating Matrix IDs", ) @override_config({"oidc_config": {"allow_existing_users": True}}) @@ -717,26 +676,26 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user.to_string(), password_hash=None) ) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, None, new_user=False ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, None, new_user=False ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -747,13 +706,11 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, None, new_user=False ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") @@ -770,14 +727,11 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + args = self.assertRenderedError("mapping_error") self.assertTrue( - str(e.value).startswith( + args[2].startswith( "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:" ) ) @@ -788,28 +742,17 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_called_once_with( + "@TEST_USER_2:test", ANY, ANY, None, new_user=False ) - self.assertEqual(mxid, "@TEST_USER_2:test") def test_map_userinfo_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" - userinfo = { - "sub": "test2", - "username": "föö", - } - token = {} - - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, + self.get_success( + _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"}) ) - self.assertEqual(str(e.value), "localpart is invalid: föö") + self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( { @@ -822,6 +765,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_map_userinfo_to_user_retries(self): """The mapping provider can retry generating an MXID if the MXID is already in use.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + store = self.hs.get_datastore() self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -830,14 +776,13 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) - ) + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + # test_user is already taken, so test_user1 gets registered instead. - self.assertEqual(mxid, "@test_user1:test") + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", ANY, ANY, None, new_user=True + ) + auth_handler.complete_sso_login.reset_mock() # Register all of the potential mxids for a particular OIDC username. self.get_success( @@ -853,12 +798,128 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + self.assertRenderedError( + "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) - self.assertEqual( - str(e.value), "Unable to generate a Matrix ID from the SSO response" + + def test_empty_localpart(self): + """Attempts to map onto an empty localpart should be rejected.""" + userinfo = { + "sub": "tester", + "username": "", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.username }}"} + } + } + } + ) + def test_null_localpart(self): + """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" + userinfo = { + "sub": "tester", + "username": None, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + def _generate_oidc_session_token( + self, + state: str, + nonce: str, + client_redirect_url: str, + ui_auth_session_id: Optional[str] = None, + ) -> str: + from synapse.handlers.oidc_handler import OidcSessionData + + return self.handler._token_generator.generate_oidc_session_token( + state=state, + session_data=OidcSessionData( + idp_id="oidc", + nonce=nonce, + client_redirect_url=client_redirect_url, + ui_auth_session_id=ui_auth_session_id, + ), ) + + +async def _make_callback_with_userinfo( + hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" +) -> None: + """Mock up an OIDC callback with the given userinfo dict + + We'll pull out the OIDC handler from the homeserver, stub out a couple of methods, + and poke in the userinfo dict as if it were the response to an OIDC userinfo call. + + Args: + hs: the HomeServer impl to send the callback to. + userinfo: the OIDC userinfo dict + client_redirect_url: the URL to redirect to on success. + """ + from synapse.handlers.oidc_handler import OidcSessionData + + handler = hs.get_oidc_handler() + provider = handler._providers["oidc"] + provider._exchange_code = simple_async_mock(return_value={}) + provider._parse_id_token = simple_async_mock(return_value=userinfo) + provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + + state = "state" + session = handler._token_generator.generate_oidc_session_token( + state=state, + session_data=OidcSessionData( + idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url, + ), + ) + request = _build_callback_request("code", state, session) + + await handler.handle_oidc_callback(request) + + +def _build_callback_request( + code: str, + state: str, + session: str, + user_agent: str = "Browser", + ip_address: str = "10.0.0.1", +): + """Builds a fake SynapseRequest to mock the browser callback + + Returns a Mock object which looks like the SynapseRequest we get from a browser + after SSO (before we return to the client) + + Args: + code: the authorization code which would have been returned by the OIDC + provider + state: the "state" param which would have been passed around in the + query param. Should be the same as was embedded in the session in + _build_oidc_session. + session: the "session" which would have been passed around in the cookie. + user_agent: the user-agent to present + ip_address: the IP address to pretend the request came from + """ + request = Mock( + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "getHeader", + ] + ) + + request.getCookie.return_value = session + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + request.getClientIP.return_value = ip_address + return request |