summary refs log tree commit diff
path: root/tests/handlers/test_oidc.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/handlers/test_oidc.py181
1 files changed, 157 insertions, 24 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 02d4b2de0d..5e9c9c2e88 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import os
 from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
@@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
 JWKS_URI = ISSUER + ".well-known/jwks.json"
 
 # config for common cases
-COMMON_CONFIG = {
+DEFAULT_CONFIG = {
+    "enabled": True,
+    "client_id": CLIENT_ID,
+    "client_secret": CLIENT_SECRET,
+    "issuer": ISSUER,
+    "scopes": SCOPES,
+    "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+}
+
+# extends the default config with explicit OAuth2 endpoints instead of using discovery
+EXPLICIT_ENDPOINT_CONFIG = {
+    **DEFAULT_CONFIG,
     "discover": False,
     "authorization_endpoint": AUTHORIZATION_ENDPOINT,
     "token_endpoint": TOKEN_ENDPOINT,
@@ -107,6 +119,32 @@ async def get_json(url):
         return {"keys": []}
 
 
+def _key_file_path() -> str:
+    """path to a file containing the private half of a test key"""
+
+    # this key was generated with:
+    #   openssl ecparam -name prime256v1 -genkey -noout |
+    #       openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
+    #
+    # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
+    # that's what Apple use, and we want to be sure that we work with Apple's keys.
+    #
+    # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
+    # keys using ASN.1. Both are then typically formatted using PEM, which says: use the
+    # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
+    # really need to care about any of that.)
+    return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
+
+
+def _public_key_file_path() -> str:
+    """path to a file containing the public half of a test key"""
+    # this was generated with:
+    #    openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
+    #
+    # See above about where oidc_test_key.p8 came from
+    return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
+
+
 class OidcHandlerTestCase(HomeserverTestCase):
     if not HAS_OIDC:
         skip = "requires OIDC"
@@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
     def default_config(self):
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
-        oidc_config = {
-            "enabled": True,
-            "client_id": CLIENT_ID,
-            "client_secret": CLIENT_SECRET,
-            "issuer": ISSUER,
-            "scopes": SCOPES,
-            "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
-        }
-
-        # Update this config with what's in the default config so that
-        # override_config works as expected.
-        oidc_config.update(config.get("oidc_config", {}))
-        config["oidc_config"] = oidc_config
-
         return config
 
     def make_homeserver(self, reactor, clock):
@@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.render_error.reset_mock()
         return args
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
         self.assertEqual(self.provider._callback_url, CALLBACK_URL)
         self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
         self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
-    @override_config({"oidc_config": {"discover": True}})
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
@@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
-    @override_config({"oidc_config": COMMON_CONFIG})
+    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
-    @override_config({"oidc_config": COMMON_CONFIG})
+    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
@@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
         h = self.provider
@@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             # Shouldn't raise with a valid userinfo, even without jwks
             force_load_metadata()
 
-    @override_config({"oidc_config": {"skip_verification": True}})
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
     def test_skip_verification(self):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
             get_awaitable_result(self.provider.load_metadata())
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["cookies"])
@@ -368,6 +395,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(params["nonce"], [nonce])
         self.assertEqual(redirect, "http://client/redirect")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_error(self):
         """Errors from the provider returned in the callback are displayed."""
         request = Mock(args={})
@@ -379,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_client", "some description")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback(self):
         """Code callback works and display errors if something went wrong.
 
@@ -480,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_session(self):
         """The callback verifies the session presence and validity"""
         request = Mock(spec=["args", "getCookie", "cookies"])
@@ -522,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
-    @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+    @override_config(
+        {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
+    )
     def test_exchange_code(self):
         """Code exchange behaves correctly and handles various error scenarios."""
         token = {"type": "bearer"}
@@ -607,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "oidc_config": {
+                "enabled": True,
+                "client_id": CLIENT_ID,
+                "issuer": ISSUER,
+                "client_auth_method": "client_secret_post",
+                "client_secret_jwt_key": {
+                    "key_file": _key_file_path(),
+                    "jwt_header": {"alg": "ES256", "kid": "ABC789"},
+                    "jwt_payload": {"iss": "DEFGHI"},
+                },
+            }
+        }
+    )
+    def test_exchange_code_jwt_key(self):
+        """Test that code exchange works with a JWK client secret."""
+        from authlib.jose import jwt
+
+        token = {"type": "bearer"}
+        self.http_client.request = simple_async_mock(
+            return_value=FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+            )
+        )
+        code = "code"
+
+        # advance the clock a bit before we start, so we aren't working with zero
+        # timestamps.
+        self.reactor.advance(1000)
+        start_time = self.reactor.seconds()
+        ret = self.get_success(self.provider._exchange_code(code))
+
+        self.assertEqual(ret, token)
+
+        # the request should have hit the token endpoint
+        kwargs = self.http_client.request.call_args[1]
+        self.assertEqual(kwargs["method"], "POST")
+        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+        # the client secret provided to the should be a jwt which can be checked with
+        # the public key
+        args = parse_qs(kwargs["data"].decode("utf-8"))
+        secret = args["client_secret"][0]
+        with open(_public_key_file_path()) as f:
+            key = f.read()
+        claims = jwt.decode(secret, key)
+        self.assertEqual(claims.header["kid"], "ABC789")
+        self.assertEqual(claims["aud"], ISSUER)
+        self.assertEqual(claims["iss"], "DEFGHI")
+        self.assertEqual(claims["sub"], CLIENT_ID)
+        self.assertEqual(claims["iat"], start_time)
+        self.assertGreater(claims["exp"], start_time)
+
+        # check the rest of the POSTed data
+        self.assertEqual(args["grant_type"], ["authorization_code"])
+        self.assertEqual(args["code"], [code])
+        self.assertEqual(args["client_id"], [CLIENT_ID])
+        self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+    @override_config(
+        {
+            "oidc_config": {
+                "enabled": True,
+                "client_id": CLIENT_ID,
+                "issuer": ISSUER,
+                "client_auth_method": "none",
+            }
+        }
+    )
+    def test_exchange_code_no_auth(self):
+        """Test that code exchange works with no client secret."""
+        token = {"type": "bearer"}
+        self.http_client.request = simple_async_mock(
+            return_value=FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+            )
+        )
+        code = "code"
+        ret = self.get_success(self.provider._exchange_code(code))
+
+        self.assertEqual(ret, token)
+
+        # the request should have hit the token endpoint
+        kwargs = self.http_client.request.call_args[1]
+        self.assertEqual(kwargs["method"], "POST")
+        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+        # check the POSTed data
+        args = parse_qs(kwargs["data"].decode("utf-8"))
+        self.assertEqual(args["grant_type"], ["authorization_code"])
+        self.assertEqual(args["code"], [code])
+        self.assertEqual(args["client_id"], [CLIENT_ID])
+        self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+    @override_config(
+        {
+            "oidc_config": {
+                **DEFAULT_CONFIG,
                 "user_mapping_provider": {
                     "module": __name__ + ".TestMappingProviderExtra"
-                }
+                },
             }
         }
     )
@@ -652,6 +780,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             new_user=True,
         )
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_user(self):
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
         auth_handler = self.hs.get_auth_handler()
@@ -692,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "Mapping provider does not support de-duplicating Matrix IDs",
         )
 
-    @override_config({"oidc_config": {"allow_existing_users": True}})
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
     def test_map_userinfo_to_existing_user(self):
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
         store = self.hs.get_datastore()
@@ -772,6 +901,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
         )
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
         self.get_success(
@@ -782,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "oidc_config": {
+                **DEFAULT_CONFIG,
                 "user_mapping_provider": {
                     "module": __name__ + ".TestMappingProviderFailures"
-                }
+                },
             }
         }
     )
@@ -829,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_empty_localpart(self):
         """Attempts to map onto an empty localpart should be rejected."""
         userinfo = {
@@ -841,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "oidc_config": {
+                **DEFAULT_CONFIG,
                 "user_mapping_provider": {
                     "config": {"localpart_template": "{{ user.username }}"}
-                }
+                },
             }
         }
     )