diff options
Diffstat (limited to '')
-rw-r--r-- | tests/handlers/test_oidc.py | 158 |
1 files changed, 96 insertions, 62 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index b3dfa40d25..cf1de28fa9 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -24,7 +24,7 @@ from synapse.handlers.sso import MappingException from synapse.server import HomeServer from synapse.types import UserID -from tests.test_utils import FakeResponse, simple_async_mock +from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.unittest import HomeserverTestCase, override_config try: @@ -40,7 +40,7 @@ ISSUER = "https://issuer/" CLIENT_ID = "test-client-id" CLIENT_SECRET = "test-client-secret" BASE_URL = "https://synapse/" -CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] AUTHORIZATION_ENDPOINT = ISSUER + "authorize" @@ -58,12 +58,6 @@ COMMON_CONFIG = { } -# The cookie name and path don't really matter, just that it has to be coherent -# between the callback & redirect handlers. -COOKIE_NAME = b"oidc_session" -COOKIE_PATH = "/_synapse/oidc" - - class TestMappingProvider: @staticmethod def parse_config(config): @@ -137,7 +131,6 @@ class OidcHandlerTestCase(HomeserverTestCase): return config def make_homeserver(self, reactor, clock): - self.http_client = Mock(spec=["get_json"]) self.http_client.get_json.side_effect = get_json self.http_client.user_agent = "Synapse Test" @@ -157,7 +150,15 @@ class OidcHandlerTestCase(HomeserverTestCase): return hs def metadata_edit(self, values): - return patch.dict(self.provider._provider_metadata, values) + """Modify the result that will be returned by the well-known query""" + + async def patched_get_json(uri): + res = await get_json(uri) + if uri == WELL_KNOWN: + res.update(values) + return res + + return patch.object(self.http_client, "get_json", patched_get_json) def assertRenderedError(self, error, error_description=None): self.render_error.assert_called_once() @@ -218,7 +219,14 @@ class OidcHandlerTestCase(HomeserverTestCase): self.http_client.get_json.assert_called_once_with(JWKS_URI) # Throw if the JWKS uri is missing - with self.metadata_edit({"jwks_uri": None}): + original = self.provider.load_metadata + + async def patched_load_metadata(): + m = (await original()).copy() + m.update({"jwks_uri": None}) + return m + + with patch.object(self.provider, "load_metadata", patched_load_metadata): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) # Return empty key set if JWKS are not used @@ -228,55 +236,60 @@ class OidcHandlerTestCase(HomeserverTestCase): self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) - @override_config({"oidc_config": COMMON_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" h = self.provider + def force_load_metadata(): + async def force_load(): + return await h.load_metadata(force=True) + + return get_awaitable_result(force_load()) + # Default test config does not throw - h._validate_metadata() + force_load_metadata() with self.metadata_edit({"issuer": None}): - self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + self.assertRaisesRegex(ValueError, "issuer", force_load_metadata) with self.metadata_edit({"issuer": "http://insecure/"}): - self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + self.assertRaisesRegex(ValueError, "issuer", force_load_metadata) with self.metadata_edit({"issuer": "https://invalid/?because=query"}): - self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata) + self.assertRaisesRegex(ValueError, "issuer", force_load_metadata) with self.metadata_edit({"authorization_endpoint": None}): self.assertRaisesRegex( - ValueError, "authorization_endpoint", h._validate_metadata + ValueError, "authorization_endpoint", force_load_metadata ) with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}): self.assertRaisesRegex( - ValueError, "authorization_endpoint", h._validate_metadata + ValueError, "authorization_endpoint", force_load_metadata ) with self.metadata_edit({"token_endpoint": None}): - self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata) with self.metadata_edit({"token_endpoint": "http://insecure/token"}): - self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata) + self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata) with self.metadata_edit({"jwks_uri": None}): - self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata) with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}): - self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata) + self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata) with self.metadata_edit({"response_types_supported": ["id_token"]}): self.assertRaisesRegex( - ValueError, "response_types_supported", h._validate_metadata + ValueError, "response_types_supported", force_load_metadata ) with self.metadata_edit( {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} ): # should not throw, as client_secret_basic is the default auth method - h._validate_metadata() + force_load_metadata() with self.metadata_edit( {"token_endpoint_auth_methods_supported": ["client_secret_post"]} @@ -284,7 +297,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRaisesRegex( ValueError, "token_endpoint_auth_methods_supported", - h._validate_metadata, + force_load_metadata, ) # Tests for configs that require the userinfo endpoint @@ -293,28 +306,30 @@ class OidcHandlerTestCase(HomeserverTestCase): h._user_profile_method = "userinfo_endpoint" self.assertTrue(h._uses_userinfo) - # Revert the profile method and do not request the "openid" scope. + # Revert the profile method and do not request the "openid" scope: this should + # mean that we check for a userinfo endpoint h._user_profile_method = "auto" h._scopes = [] self.assertTrue(h._uses_userinfo) - self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata) + with self.metadata_edit({"userinfo_endpoint": None}): + self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata) - with self.metadata_edit( - {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None} - ): - # Shouldn't raise with a valid userinfo, even without - h._validate_metadata() + with self.metadata_edit({"jwks_uri": None}): + # Shouldn't raise with a valid userinfo, even without jwks + force_load_metadata() @override_config({"oidc_config": {"skip_verification": True}}) def test_skip_verification(self): """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw - self.provider._validate_metadata() + get_awaitable_result(self.provider.load_metadata()) def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" - req = Mock(spec=["addCookie"]) + req = Mock(spec=["cookies"]) + req.cookies = [] + url = self.get_success( self.provider.handle_redirect_request(req, b"http://client/redirect") ) @@ -333,16 +348,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(len(params["state"]), 1) self.assertEqual(len(params["nonce"]), 1) - # Check what is in the cookie - # note: python3.5 mock does not have the .called_once() method - calls = req.addCookie.call_args_list - self.assertEqual(len(calls), 1) # called once - # For some reason, call.args does not work with python3.5 - args = calls[0][0] - kwargs = calls[0][1] - self.assertEqual(args[0], COOKIE_NAME) - self.assertEqual(kwargs["path"], COOKIE_PATH) - cookie = args[1] + # Check what is in the cookies + self.assertEqual(len(req.cookies), 2) # two cookies + cookie_header = req.cookies[0] + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + parts = [p.strip() for p in cookie_header.split(b";")] + self.assertIn(b"Path=/_synapse/client/oidc", parts) + name, cookie = parts[0].split(b"=") + self.assertEqual(name, b"oidc_session") macaroon = pymacaroons.Macaroon.deserialize(cookie) state = self.handler._token_generator._get_value_from_macaroon( @@ -419,7 +434,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, + expected_user_id, request, client_redirect_url, None, new_user=True ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -450,7 +465,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, + expected_user_id, request, client_redirect_url, None, new_user=False ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() @@ -473,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase): def test_callback_session(self): """The callback verifies the session presence and validity""" - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock(spec=["args", "getCookie", "cookies"]) # Missing cookie request.args = {} @@ -496,7 +511,9 @@ class OidcHandlerTestCase(HomeserverTestCase): # Mismatching session session = self._generate_oidc_session_token( - state="state", nonce="nonce", client_redirect_url="http://client/redirect", + state="state", + nonce="nonce", + client_redirect_url="http://client/redirect", ) request.args = {} request.args[b"state"] = [b"mismatching state"] @@ -551,7 +568,9 @@ class OidcHandlerTestCase(HomeserverTestCase): # Internal server error with no JSON body self.http_client.request = simple_async_mock( return_value=FakeResponse( - code=500, phrase=b"Internal Server Error", body=b"Not JSON", + code=500, + phrase=b"Internal Server Error", + body=b"Not JSON", ) ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) @@ -571,7 +590,11 @@ class OidcHandlerTestCase(HomeserverTestCase): # 4xx error without "error" field self.http_client.request = simple_async_mock( - return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) + return_value=FakeResponse( + code=400, + phrase=b"Bad request", + body=b"{}", + ) ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) self.assertEqual(exc.value.error, "server_error") @@ -579,7 +602,9 @@ class OidcHandlerTestCase(HomeserverTestCase): # 2xx error with "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse( - code=200, phrase=b"OK", body=b'{"error": "some_error"}', + code=200, + phrase=b"OK", + body=b'{"error": "some_error"}', ) ) exc = self.get_failure(self.provider._exchange_code(code), OidcError) @@ -616,14 +641,20 @@ class OidcHandlerTestCase(HomeserverTestCase): state = "state" client_redirect_url = "http://client/redirect" session = self._generate_oidc_session_token( - state=state, nonce="nonce", client_redirect_url=client_redirect_url, + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, ) request = _build_callback_request("code", state, session) self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - "@foo:test", request, client_redirect_url, {"phone": "1234567"}, + "@foo:test", + request, + client_redirect_url, + {"phone": "1234567"}, + new_user=True, ) def test_map_userinfo_to_user(self): @@ -637,7 +668,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", ANY, ANY, None, + "@test_user:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -648,7 +679,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", ANY, ANY, None, + "@test_user_2:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -685,14 +716,14 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -707,7 +738,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -743,7 +774,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", ANY, ANY, None, + "@TEST_USER_2:test", ANY, ANY, None, new_user=False ) def test_map_userinfo_to_invalid_localpart(self): @@ -779,7 +810,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", ANY, ANY, None, + "@test_user1:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -875,7 +906,9 @@ async def _make_callback_with_userinfo( session = handler._token_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url, + idp_id="oidc", + nonce="nonce", + client_redirect_url=client_redirect_url, ), ) request = _build_callback_request("code", state, session) @@ -909,13 +942,14 @@ def _build_callback_request( spec=[ "args", "getCookie", - "addCookie", + "cookies", "requestHeaders", "getClientIP", "getHeader", ] ) + request.cookies = [] request.getCookie.return_value = session request.args = {} request.args[b"code"] = [code.encode("utf-8")] |