summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--MANIFEST.in7
-rw-r--r--changelog.d/9549.feature1
-rw-r--r--docs/openid.md42
-rw-r--r--docs/sample_config.yaml21
-rw-r--r--synapse/config/_base.py38
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/oidc_config.py89
-rw-r--r--synapse/handlers/oidc_handler.py101
-rw-r--r--tests/handlers/oidc_test_key.p85
-rw-r--r--tests/handlers/oidc_test_key.pub.pem4
-rw-r--r--tests/handlers/test_oidc.py181
11 files changed, 444 insertions, 47 deletions
diff --git a/MANIFEST.in b/MANIFEST.in
index 120ce5b776..25d1cb758e 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -20,9 +20,10 @@ recursive-include scripts *
 recursive-include scripts-dev *
 recursive-include synapse *.pyi
 recursive-include tests *.py
-include tests/http/ca.crt
-include tests/http/ca.key
-include tests/http/server.key
+recursive-include tests *.pem
+recursive-include tests *.p8
+recursive-include tests *.crt
+recursive-include tests *.key
 
 recursive-include synapse/res *
 recursive-include synapse/static *.css
diff --git a/changelog.d/9549.feature b/changelog.d/9549.feature
new file mode 100644
index 0000000000..709e61eced
--- /dev/null
+++ b/changelog.d/9549.feature
@@ -0,0 +1 @@
+Add support for generating JSON Web Tokens dynamically for use as OIDC client secrets.
diff --git a/docs/openid.md b/docs/openid.md
index 263bc9f6f8..01205d1220 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -386,7 +386,7 @@ oidc_providers:
       config:
         subject_claim: "id"
         localpart_template: "{{ user.login }}"
-        display_name_template: "{{ user.full_name }}" 
+        display_name_template: "{{ user.full_name }}"
 ```
 
 ### XWiki
@@ -401,8 +401,7 @@ oidc_providers:
     idp_name: "XWiki"
     issuer: "https://myxwikihost/xwiki/oidc/"
     client_id: "your-client-id" # TO BE FILLED
-    # Needed until https://github.com/matrix-org/synapse/issues/9212 is fixed
-    client_secret: "dontcare"
+    client_auth_method: none
     scopes: ["openid", "profile"]
     user_profile_method: "userinfo_endpoint"
     user_mapping_provider:
@@ -410,3 +409,40 @@ oidc_providers:
         localpart_template: "{{ user.preferred_username }}"
         display_name_template: "{{ user.name }}"
 ```
+
+## Apple
+
+Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account.
+
+You will need to create a new "Services ID" for SiWA, and create and download a
+private key with "SiWA" enabled.
+
+As well as the private key file, you will need:
+ * Client ID: the "identifier" you gave the "Services ID"
+ * Team ID: a 10-character ID associated with your developer account.
+ * Key ID: the 10-character identifier for the key.
+
+https://help.apple.com/developer-account/?lang=en#/dev77c875b7e has more
+documentation on setting up SiWA.
+
+The synapse config will look like this:
+
+```yaml
+  - idp_id: apple
+    idp_name: Apple
+    issuer: "https://appleid.apple.com"
+    client_id: "your-client-id" # Set to the "identifier" for your "ServicesID"
+    client_auth_method: "client_secret_post"
+    client_secret_jwt_key:
+      key_file: "/path/to/AuthKey_KEYIDCODE.p8"  # point to your key file
+      jwt_header:
+        alg: ES256
+        kid: "KEYIDCODE"   # Set to the 10-char Key ID
+      jwt_payload:
+        iss: TEAMIDCODE    # Set to the 10-char Team ID
+    scopes: ["name", "email", "openid"]
+    authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
+    user_mapping_provider:
+      config:
+        email_template: "{{ user.email }}"
+```
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index c95a4f5970..c32ee4a897 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1779,7 +1779,26 @@ saml2_config:
 #
 #   client_id: Required. oauth2 client id to use.
 #
-#   client_secret: Required. oauth2 client secret to use.
+#   client_secret: oauth2 client secret to use. May be omitted if
+#        client_secret_jwt_key is given, or if client_auth_method is 'none'.
+#
+#   client_secret_jwt_key: Alternative to client_secret: details of a key used
+#      to create a JSON Web Token to be used as an OAuth2 client secret. If
+#      given, must be a dictionary with the following properties:
+#
+#          key: a pem-encoded signing key. Must be a suitable key for the
+#              algorithm specified. Required unless 'key_file' is given.
+#
+#          key_file: the path to file containing a pem-encoded signing key file.
+#              Required unless 'key' is given.
+#
+#          jwt_header: a dictionary giving properties to include in the JWT
+#              header. Must include the key 'alg', giving the algorithm used to
+#              sign the JWT, such as "ES256", using the JWA identifiers in
+#              RFC7518.
+#
+#          jwt_payload: an optional dictionary giving properties to include in
+#              the JWT payload. Normally this should include an 'iss' key.
 #
 #   client_auth_method: auth method to use when exchanging the token. Valid
 #       values are 'client_secret_basic' (default), 'client_secret_post' and
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 4026966711..ba9cd63cf2 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -212,9 +212,8 @@ class Config:
 
     @classmethod
     def read_file(cls, file_path, config_name):
-        cls.check_file(file_path, config_name)
-        with open(file_path) as file_stream:
-            return file_stream.read()
+        """Deprecated: call read_file directly"""
+        return read_file(file_path, (config_name,))
 
     def read_template(self, filename: str) -> jinja2.Template:
         """Load a template file from disk.
@@ -894,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
         return self._get_instance(key)
 
 
-__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
+def read_file(file_path: Any, config_path: Iterable[str]) -> str:
+    """Check the given file exists, and read it into a string
+
+    If it does not, emit an error indicating the problem
+
+    Args:
+        file_path: the file to be read
+        config_path: where in the configuration file_path came from, so that a useful
+           error can be emitted if it does not exist.
+    Returns:
+        content of the file.
+    Raises:
+        ConfigError if there is a problem reading the file.
+    """
+    if not isinstance(file_path, str):
+        raise ConfigError("%r is not a string", config_path)
+
+    try:
+        os.stat(file_path)
+        with open(file_path) as file_stream:
+            return file_stream.read()
+    except OSError as e:
+        raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
+
+
+__all__ = [
+    "Config",
+    "RootConfig",
+    "ShardedWorkerHandlingConfig",
+    "RoutableShardedWorkerHandlingConfig",
+    "read_file",
+]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index db16c86f50..e896fd34e2 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig:
 
 class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
     def get_instance(self, key: str) -> str: ...
+
+def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index a27594befc..7f5e449eb2 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 from collections import Counter
-from typing import Iterable, Optional, Tuple, Type
+from typing import Iterable, Mapping, Optional, Tuple, Type
 
 import attr
 
@@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict
 from synapse.util.module_loader import load_module
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
-from ._base import Config, ConfigError
+from ._base import Config, ConfigError, read_file
 
 DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
 
@@ -97,7 +97,26 @@ class OIDCConfig(Config):
         #
         #   client_id: Required. oauth2 client id to use.
         #
-        #   client_secret: Required. oauth2 client secret to use.
+        #   client_secret: oauth2 client secret to use. May be omitted if
+        #        client_secret_jwt_key is given, or if client_auth_method is 'none'.
+        #
+        #   client_secret_jwt_key: Alternative to client_secret: details of a key used
+        #      to create a JSON Web Token to be used as an OAuth2 client secret. If
+        #      given, must be a dictionary with the following properties:
+        #
+        #          key: a pem-encoded signing key. Must be a suitable key for the
+        #              algorithm specified. Required unless 'key_file' is given.
+        #
+        #          key_file: the path to file containing a pem-encoded signing key file.
+        #              Required unless 'key' is given.
+        #
+        #          jwt_header: a dictionary giving properties to include in the JWT
+        #              header. Must include the key 'alg', giving the algorithm used to
+        #              sign the JWT, such as "ES256", using the JWA identifiers in
+        #              RFC7518.
+        #
+        #          jwt_payload: an optional dictionary giving properties to include in
+        #              the JWT payload. Normally this should include an 'iss' key.
         #
         #   client_auth_method: auth method to use when exchanging the token. Valid
         #       values are 'client_secret_basic' (default), 'client_secret_post' and
@@ -240,7 +259,7 @@ class OIDCConfig(Config):
 # jsonschema definition of the configuration settings for an oidc identity provider
 OIDC_PROVIDER_CONFIG_SCHEMA = {
     "type": "object",
-    "required": ["issuer", "client_id", "client_secret"],
+    "required": ["issuer", "client_id"],
     "properties": {
         "idp_id": {
             "type": "string",
@@ -262,6 +281,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
         "issuer": {"type": "string"},
         "client_id": {"type": "string"},
         "client_secret": {"type": "string"},
+        "client_secret_jwt_key": {
+            "type": "object",
+            "required": ["jwt_header"],
+            "oneOf": [
+                {"required": ["key"]},
+                {"required": ["key_file"]},
+            ],
+            "properties": {
+                "key": {"type": "string"},
+                "key_file": {"type": "string"},
+                "jwt_header": {
+                    "type": "object",
+                    "required": ["alg"],
+                    "properties": {
+                        "alg": {"type": "string"},
+                    },
+                    "additionalProperties": {"type": "string"},
+                },
+                "jwt_payload": {
+                    "type": "object",
+                    "additionalProperties": {"type": "string"},
+                },
+            },
+        },
         "client_auth_method": {
             "type": "string",
             # the following list is the same as the keys of
@@ -404,6 +447,20 @@ def _parse_oidc_config_dict(
                 "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
             ) from e
 
+    client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
+    client_secret_jwt_key = None  # type: Optional[OidcProviderClientSecretJwtKey]
+    if client_secret_jwt_key_config is not None:
+        keyfile = client_secret_jwt_key_config.get("key_file")
+        if keyfile:
+            key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
+        else:
+            key = client_secret_jwt_key_config["key"]
+        client_secret_jwt_key = OidcProviderClientSecretJwtKey(
+            key=key,
+            jwt_header=client_secret_jwt_key_config["jwt_header"],
+            jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
+        )
+
     return OidcProviderConfig(
         idp_id=idp_id,
         idp_name=oidc_config.get("idp_name", "OIDC"),
@@ -412,7 +469,8 @@ def _parse_oidc_config_dict(
         discover=oidc_config.get("discover", True),
         issuer=oidc_config["issuer"],
         client_id=oidc_config["client_id"],
-        client_secret=oidc_config["client_secret"],
+        client_secret=oidc_config.get("client_secret"),
+        client_secret_jwt_key=client_secret_jwt_key,
         client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
         scopes=oidc_config.get("scopes", ["openid"]),
         authorization_endpoint=oidc_config.get("authorization_endpoint"),
@@ -428,6 +486,18 @@ def _parse_oidc_config_dict(
 
 
 @attr.s(slots=True, frozen=True)
+class OidcProviderClientSecretJwtKey:
+    # a pem-encoded signing key
+    key = attr.ib(type=str)
+
+    # properties to include in the JWT header
+    jwt_header = attr.ib(type=Mapping[str, str])
+
+    # properties to include in the JWT payload.
+    jwt_payload = attr.ib(type=Mapping[str, str])
+
+
+@attr.s(slots=True, frozen=True)
 class OidcProviderConfig:
     # a unique identifier for this identity provider. Used in the 'user_external_ids'
     # table, as well as the query/path parameter used in the login protocol.
@@ -452,8 +522,13 @@ class OidcProviderConfig:
     # oauth2 client id to use
     client_id = attr.ib(type=str)
 
-    # oauth2 client secret to use
-    client_secret = attr.ib(type=str)
+    # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
+    # a secret.
+    client_secret = attr.ib(type=Optional[str])
+
+    # key to use to construct a JWT to use as a client secret. May be `None` if
+    # `client_secret` is set.
+    client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
 
     # auth method to use when exchanging the token.
     # Valid values are 'client_secret_basic', 'client_secret_post' and
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index b4a74390cc..825fadb76f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,13 +15,13 @@
 # limitations under the License.
 import inspect
 import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
 from urllib.parse import urlencode
 
 import attr
 import pymacaroons
 from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken
+from authlib.jose import JsonWebToken, jwt
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
 from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
@@ -35,12 +36,15 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.config.oidc_config import OidcProviderConfig
+from synapse.config.oidc_config import (
+    OidcProviderClientSecretJwtKey,
+    OidcProviderConfig,
+)
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
-from synapse.util import json_decoder
+from synapse.util import Clock, json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 
@@ -276,9 +280,21 @@ class OidcProvider:
 
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
+
+        client_secret = None  # type: Union[None, str, JwtClientSecret]
+        if provider.client_secret:
+            client_secret = provider.client_secret
+        elif provider.client_secret_jwt_key:
+            client_secret = JwtClientSecret(
+                provider.client_secret_jwt_key,
+                provider.client_id,
+                provider.issuer,
+                hs.get_clock(),
+            )
+
         self._client_auth = ClientAuth(
             provider.client_id,
-            provider.client_secret,
+            client_secret,
             provider.client_auth_method,
         )  # type: ClientAuth
         self._client_auth_method = provider.client_auth_method
@@ -977,6 +993,81 @@ class OidcProvider:
         return str(remote_user_id)
 
 
+# number of seconds a newly-generated client secret should be valid for
+CLIENT_SECRET_VALIDITY_SECONDS = 3600
+
+# minimum remaining validity on a client secret before we should generate a new one
+CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
+
+
+class JwtClientSecret:
+    """A class which generates a new client secret on demand, based on a JWK
+
+    This implementation is designed to comply with the requirements for Apple Sign in:
+    https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
+
+    It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
+    but it's worth noting that we still put the generated secret in the "client_secret"
+    field (or rather, whereever client_auth_method puts it) rather than in a
+    client_assertion field in the body as that RFC seems to require.
+    """
+
+    def __init__(
+        self,
+        key: OidcProviderClientSecretJwtKey,
+        oauth_client_id: str,
+        oauth_issuer: str,
+        clock: Clock,
+    ):
+        self._key = key
+        self._oauth_client_id = oauth_client_id
+        self._oauth_issuer = oauth_issuer
+        self._clock = clock
+        self._cached_secret = b""
+        self._cached_secret_replacement_time = 0
+
+    def __str__(self):
+        # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
+        # encode_client_secret_basic, which calls "{}".format(secret), which ends up
+        # here.
+        return self._get_secret().decode("ascii")
+
+    def __bytes__(self):
+        # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
+        # encode_client_secret_post, which ends up here.
+        return self._get_secret()
+
+    def _get_secret(self) -> bytes:
+        now = self._clock.time()
+
+        # if we have enough validity on our existing secret, use it
+        if now < self._cached_secret_replacement_time:
+            return self._cached_secret
+
+        issued_at = int(now)
+        expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
+
+        # we copy the configured header because jwt.encode modifies it.
+        header = dict(self._key.jwt_header)
+
+        # see https://tools.ietf.org/html/rfc7523#section-3
+        payload = {
+            "sub": self._oauth_client_id,
+            "aud": self._oauth_issuer,
+            "iat": issued_at,
+            "exp": expires_at,
+            **self._key.jwt_payload,
+        }
+        logger.info(
+            "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
+        )
+        self._cached_secret = jwt.encode(header, payload, self._key.key)
+        self._cached_secret_replacement_time = (
+            expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
+        )
+        return self._cached_secret
+
+
 class OidcSessionTokenGenerator:
     """Methods for generating and checking OIDC Session cookies."""
 
diff --git a/tests/handlers/oidc_test_key.p8 b/tests/handlers/oidc_test_key.p8
new file mode 100644
index 0000000000..bb92976333
--- /dev/null
+++ b/tests/handlers/oidc_test_key.p8
@@ -0,0 +1,5 @@
+-----BEGIN PRIVATE KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
+Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
+P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
+-----END PRIVATE KEY-----
diff --git a/tests/handlers/oidc_test_key.pub.pem b/tests/handlers/oidc_test_key.pub.pem
new file mode 100644
index 0000000000..176d4a4b4b
--- /dev/null
+++ b/tests/handlers/oidc_test_key.pub.pem
@@ -0,0 +1,4 @@
+-----BEGIN PUBLIC KEY-----
+MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
+QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
+-----END PUBLIC KEY-----
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 02d4b2de0d..5e9c9c2e88 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import os
 from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
@@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
 JWKS_URI = ISSUER + ".well-known/jwks.json"
 
 # config for common cases
-COMMON_CONFIG = {
+DEFAULT_CONFIG = {
+    "enabled": True,
+    "client_id": CLIENT_ID,
+    "client_secret": CLIENT_SECRET,
+    "issuer": ISSUER,
+    "scopes": SCOPES,
+    "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+}
+
+# extends the default config with explicit OAuth2 endpoints instead of using discovery
+EXPLICIT_ENDPOINT_CONFIG = {
+    **DEFAULT_CONFIG,
     "discover": False,
     "authorization_endpoint": AUTHORIZATION_ENDPOINT,
     "token_endpoint": TOKEN_ENDPOINT,
@@ -107,6 +119,32 @@ async def get_json(url):
         return {"keys": []}
 
 
+def _key_file_path() -> str:
+    """path to a file containing the private half of a test key"""
+
+    # this key was generated with:
+    #   openssl ecparam -name prime256v1 -genkey -noout |
+    #       openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
+    #
+    # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
+    # that's what Apple use, and we want to be sure that we work with Apple's keys.
+    #
+    # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
+    # keys using ASN.1. Both are then typically formatted using PEM, which says: use the
+    # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
+    # really need to care about any of that.)
+    return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
+
+
+def _public_key_file_path() -> str:
+    """path to a file containing the public half of a test key"""
+    # this was generated with:
+    #    openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
+    #
+    # See above about where oidc_test_key.p8 came from
+    return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
+
+
 class OidcHandlerTestCase(HomeserverTestCase):
     if not HAS_OIDC:
         skip = "requires OIDC"
@@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
     def default_config(self):
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
-        oidc_config = {
-            "enabled": True,
-            "client_id": CLIENT_ID,
-            "client_secret": CLIENT_SECRET,
-            "issuer": ISSUER,
-            "scopes": SCOPES,
-            "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
-        }
-
-        # Update this config with what's in the default config so that
-        # override_config works as expected.
-        oidc_config.update(config.get("oidc_config", {}))
-        config["oidc_config"] = oidc_config
-
         return config
 
     def make_homeserver(self, reactor, clock):
@@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.render_error.reset_mock()
         return args
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
         self.assertEqual(self.provider._callback_url, CALLBACK_URL)
         self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
         self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
-    @override_config({"oidc_config": {"discover": True}})
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
@@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
-    @override_config({"oidc_config": COMMON_CONFIG})
+    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
-    @override_config({"oidc_config": COMMON_CONFIG})
+    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
@@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
         h = self.provider
@@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             # Shouldn't raise with a valid userinfo, even without jwks
             force_load_metadata()
 
-    @override_config({"oidc_config": {"skip_verification": True}})
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
     def test_skip_verification(self):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
             get_awaitable_result(self.provider.load_metadata())
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["cookies"])
@@ -368,6 +395,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(params["nonce"], [nonce])
         self.assertEqual(redirect, "http://client/redirect")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_error(self):
         """Errors from the provider returned in the callback are displayed."""
         request = Mock(args={})
@@ -379,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_client", "some description")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback(self):
         """Code callback works and display errors if something went wrong.
 
@@ -480,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_session(self):
         """The callback verifies the session presence and validity"""
         request = Mock(spec=["args", "getCookie", "cookies"])
@@ -522,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
-    @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+    @override_config(
+        {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
+    )
     def test_exchange_code(self):
         """Code exchange behaves correctly and handles various error scenarios."""
         token = {"type": "bearer"}
@@ -607,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "oidc_config": {
+                "enabled": True,
+                "client_id": CLIENT_ID,
+                "issuer": ISSUER,
+                "client_auth_method": "client_secret_post",
+                "client_secret_jwt_key": {
+                    "key_file": _key_file_path(),
+                    "jwt_header": {"alg": "ES256", "kid": "ABC789"},
+                    "jwt_payload": {"iss": "DEFGHI"},
+                },
+            }
+        }
+    )
+    def test_exchange_code_jwt_key(self):
+        """Test that code exchange works with a JWK client secret."""
+        from authlib.jose import jwt
+
+        token = {"type": "bearer"}
+        self.http_client.request = simple_async_mock(
+            return_value=FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+            )
+        )
+        code = "code"
+
+        # advance the clock a bit before we start, so we aren't working with zero
+        # timestamps.
+        self.reactor.advance(1000)
+        start_time = self.reactor.seconds()
+        ret = self.get_success(self.provider._exchange_code(code))
+
+        self.assertEqual(ret, token)
+
+        # the request should have hit the token endpoint
+        kwargs = self.http_client.request.call_args[1]
+        self.assertEqual(kwargs["method"], "POST")
+        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+        # the client secret provided to the should be a jwt which can be checked with
+        # the public key
+        args = parse_qs(kwargs["data"].decode("utf-8"))
+        secret = args["client_secret"][0]
+        with open(_public_key_file_path()) as f:
+            key = f.read()
+        claims = jwt.decode(secret, key)
+        self.assertEqual(claims.header["kid"], "ABC789")
+        self.assertEqual(claims["aud"], ISSUER)
+        self.assertEqual(claims["iss"], "DEFGHI")
+        self.assertEqual(claims["sub"], CLIENT_ID)
+        self.assertEqual(claims["iat"], start_time)
+        self.assertGreater(claims["exp"], start_time)
+
+        # check the rest of the POSTed data
+        self.assertEqual(args["grant_type"], ["authorization_code"])
+        self.assertEqual(args["code"], [code])
+        self.assertEqual(args["client_id"], [CLIENT_ID])
+        self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+    @override_config(
+        {
+            "oidc_config": {
+                "enabled": True,
+                "client_id": CLIENT_ID,
+                "issuer": ISSUER,
+                "client_auth_method": "none",
+            }
+        }
+    )
+    def test_exchange_code_no_auth(self):
+        """Test that code exchange works with no client secret."""
+        token = {"type": "bearer"}
+        self.http_client.request = simple_async_mock(
+            return_value=FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+            )
+        )
+        code = "code"
+        ret = self.get_success(self.provider._exchange_code(code))
+
+        self.assertEqual(ret, token)
+
+        # the request should have hit the token endpoint
+        kwargs = self.http_client.request.call_args[1]
+        self.assertEqual(kwargs["method"], "POST")
+        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+        # check the POSTed data
+        args = parse_qs(kwargs["data"].decode("utf-8"))
+        self.assertEqual(args["grant_type"], ["authorization_code"])
+        self.assertEqual(args["code"], [code])
+        self.assertEqual(args["client_id"], [CLIENT_ID])
+        self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+    @override_config(
+        {
+            "oidc_config": {
+                **DEFAULT_CONFIG,
                 "user_mapping_provider": {
                     "module": __name__ + ".TestMappingProviderExtra"
-                }
+                },
             }
         }
     )
@@ -652,6 +780,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             new_user=True,
         )
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_user(self):
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
         auth_handler = self.hs.get_auth_handler()
@@ -692,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "Mapping provider does not support de-duplicating Matrix IDs",
         )
 
-    @override_config({"oidc_config": {"allow_existing_users": True}})
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
     def test_map_userinfo_to_existing_user(self):
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
         store = self.hs.get_datastore()
@@ -772,6 +901,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
         )
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
         self.get_success(
@@ -782,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "oidc_config": {
+                **DEFAULT_CONFIG,
                 "user_mapping_provider": {
                     "module": __name__ + ".TestMappingProviderFailures"
-                }
+                },
             }
         }
     )
@@ -829,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_empty_localpart(self):
         """Attempts to map onto an empty localpart should be rejected."""
         userinfo = {
@@ -841,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "oidc_config": {
+                **DEFAULT_CONFIG,
                 "user_mapping_provider": {
                     "config": {"localpart_template": "{{ user.username }}"}
-                }
+                },
             }
         }
     )