diff options
Diffstat (limited to '')
-rw-r--r-- | tests/handlers/test_oidc.py | 181 |
1 files changed, 157 insertions, 24 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 02d4b2de0d..5e9c9c2e88 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import os from urllib.parse import parse_qs, urlparse from mock import ANY, Mock, patch @@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration" JWKS_URI = ISSUER + ".well-known/jwks.json" # config for common cases -COMMON_CONFIG = { +DEFAULT_CONFIG = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, +} + +# extends the default config with explicit OAuth2 endpoints instead of using discovery +EXPLICIT_ENDPOINT_CONFIG = { + **DEFAULT_CONFIG, "discover": False, "authorization_endpoint": AUTHORIZATION_ENDPOINT, "token_endpoint": TOKEN_ENDPOINT, @@ -107,6 +119,32 @@ async def get_json(url): return {"keys": []} +def _key_file_path() -> str: + """path to a file containing the private half of a test key""" + + # this key was generated with: + # openssl ecparam -name prime256v1 -genkey -noout | + # openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8 + # + # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because + # that's what Apple use, and we want to be sure that we work with Apple's keys. + # + # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing + # keys using ASN.1. Both are then typically formatted using PEM, which says: use the + # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't + # really need to care about any of that.) + return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8") + + +def _public_key_file_path() -> str: + """path to a file containing the public half of a test key""" + # this was generated with: + # openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem + # + # See above about where oidc_test_key.p8 came from + return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem") + + class OidcHandlerTestCase(HomeserverTestCase): if not HAS_OIDC: skip = "requires OIDC" @@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase): def default_config(self): config = super().default_config() config["public_baseurl"] = BASE_URL - oidc_config = { - "enabled": True, - "client_id": CLIENT_ID, - "client_secret": CLIENT_SECRET, - "issuer": ISSUER, - "scopes": SCOPES, - "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"}, - } - - # Update this config with what's in the default config so that - # override_config works as expected. - oidc_config.update(config.get("oidc_config", {})) - config["oidc_config"] = oidc_config - return config def make_homeserver(self, reactor, clock): @@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase): self.render_error.reset_mock() return args + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" self.assertEqual(self.provider._callback_url, CALLBACK_URL) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) - @override_config({"oidc_config": {"discover": True}}) + @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) def test_discovery(self): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid @@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() - @override_config({"oidc_config": COMMON_CONFIG}) + @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_no_discovery(self): """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() - @override_config({"oidc_config": COMMON_CONFIG}) + @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) def test_load_jwks(self): """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) @@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_validate_config(self): """Provider metadatas are extensively validated.""" h = self.provider @@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase): # Shouldn't raise with a valid userinfo, even without jwks force_load_metadata() - @override_config({"oidc_config": {"skip_verification": True}}) + @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) def test_skip_verification(self): """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw get_awaitable_result(self.provider.load_metadata()) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["cookies"]) @@ -368,6 +395,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(params["nonce"], [nonce]) self.assertEqual(redirect, "http://client/redirect") + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_error(self): """Errors from the provider returned in the callback are displayed.""" request = Mock(args={}) @@ -379,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_client", "some description") + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback(self): """Code callback works and display errors if something went wrong. @@ -480,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_callback_session(self): """The callback verifies the session presence and validity""" request = Mock(spec=["args", "getCookie", "cookies"]) @@ -522,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") - @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) + @override_config( + {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} + ) def test_exchange_code(self): """Code exchange behaves correctly and handles various error scenarios.""" token = {"type": "bearer"} @@ -607,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config( { "oidc_config": { + "enabled": True, + "client_id": CLIENT_ID, + "issuer": ISSUER, + "client_auth_method": "client_secret_post", + "client_secret_jwt_key": { + "key_file": _key_file_path(), + "jwt_header": {"alg": "ES256", "kid": "ABC789"}, + "jwt_payload": {"iss": "DEFGHI"}, + }, + } + } + ) + def test_exchange_code_jwt_key(self): + """Test that code exchange works with a JWK client secret.""" + from authlib.jose import jwt + + token = {"type": "bearer"} + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") + ) + ) + code = "code" + + # advance the clock a bit before we start, so we aren't working with zero + # timestamps. + self.reactor.advance(1000) + start_time = self.reactor.seconds() + ret = self.get_success(self.provider._exchange_code(code)) + + self.assertEqual(ret, token) + + # the request should have hit the token endpoint + kwargs = self.http_client.request.call_args[1] + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + + # the client secret provided to the should be a jwt which can be checked with + # the public key + args = parse_qs(kwargs["data"].decode("utf-8")) + secret = args["client_secret"][0] + with open(_public_key_file_path()) as f: + key = f.read() + claims = jwt.decode(secret, key) + self.assertEqual(claims.header["kid"], "ABC789") + self.assertEqual(claims["aud"], ISSUER) + self.assertEqual(claims["iss"], "DEFGHI") + self.assertEqual(claims["sub"], CLIENT_ID) + self.assertEqual(claims["iat"], start_time) + self.assertGreater(claims["exp"], start_time) + + # check the rest of the POSTed data + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + + @override_config( + { + "oidc_config": { + "enabled": True, + "client_id": CLIENT_ID, + "issuer": ISSUER, + "client_auth_method": "none", + } + } + ) + def test_exchange_code_no_auth(self): + """Test that code exchange works with no client secret.""" + token = {"type": "bearer"} + self.http_client.request = simple_async_mock( + return_value=FakeResponse( + code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8") + ) + ) + code = "code" + ret = self.get_success(self.provider._exchange_code(code)) + + self.assertEqual(ret, token) + + # the request should have hit the token endpoint + kwargs = self.http_client.request.call_args[1] + self.assertEqual(kwargs["method"], "POST") + self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT) + + # check the POSTed data + args = parse_qs(kwargs["data"].decode("utf-8")) + self.assertEqual(args["grant_type"], ["authorization_code"]) + self.assertEqual(args["code"], [code]) + self.assertEqual(args["client_id"], [CLIENT_ID]) + self.assertEqual(args["redirect_uri"], [CALLBACK_URL]) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, "user_mapping_provider": { "module": __name__ + ".TestMappingProviderExtra" - } + }, } } ) @@ -652,6 +780,7 @@ class OidcHandlerTestCase(HomeserverTestCase): new_user=True, ) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_user(self): """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" auth_handler = self.hs.get_auth_handler() @@ -692,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "Mapping provider does not support de-duplicating Matrix IDs", ) - @override_config({"oidc_config": {"allow_existing_users": True}}) + @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) def test_map_userinfo_to_existing_user(self): """Existing users can log in with OpenID Connect when allow_existing_users is True.""" store = self.hs.get_datastore() @@ -772,6 +901,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False ) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_map_userinfo_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" self.get_success( @@ -782,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config( { "oidc_config": { + **DEFAULT_CONFIG, "user_mapping_provider": { "module": __name__ + ".TestMappingProviderFailures" - } + }, } } ) @@ -829,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) + @override_config({"oidc_config": DEFAULT_CONFIG}) def test_empty_localpart(self): """Attempts to map onto an empty localpart should be rejected.""" userinfo = { @@ -841,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config( { "oidc_config": { + **DEFAULT_CONFIG, "user_mapping_provider": { "config": {"localpart_template": "{{ user.username }}"} - } + }, } } ) |