summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-03-09 15:03:37 +0000
committerGitHub <noreply@github.com>2021-03-09 15:03:37 +0000
commiteaada74075a4567c489fff6ae2206f2af8298fd4 (patch)
tree5b57aaa4ecea74f84d53d651d45490dd859286ad /synapse
parentRetry 5xx errors in federation client (#9567) (diff)
downloadsynapse-eaada74075a4567c489fff6ae2206f2af8298fd4.tar.xz
JWT OIDC secrets for Sign in with Apple (#9549)
Apple had to be special. They want a client secret which is generated from an EC key.

Fixes #9220. Also fixes #9212 while I'm here.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/_base.py38
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/oidc_config.py89
-rw-r--r--synapse/handlers/oidc_handler.py101
4 files changed, 214 insertions, 16 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 4026966711..ba9cd63cf2 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -212,9 +212,8 @@ class Config:
 
     @classmethod
     def read_file(cls, file_path, config_name):
-        cls.check_file(file_path, config_name)
-        with open(file_path) as file_stream:
-            return file_stream.read()
+        """Deprecated: call read_file directly"""
+        return read_file(file_path, (config_name,))
 
     def read_template(self, filename: str) -> jinja2.Template:
         """Load a template file from disk.
@@ -894,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
         return self._get_instance(key)
 
 
-__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
+def read_file(file_path: Any, config_path: Iterable[str]) -> str:
+    """Check the given file exists, and read it into a string
+
+    If it does not, emit an error indicating the problem
+
+    Args:
+        file_path: the file to be read
+        config_path: where in the configuration file_path came from, so that a useful
+           error can be emitted if it does not exist.
+    Returns:
+        content of the file.
+    Raises:
+        ConfigError if there is a problem reading the file.
+    """
+    if not isinstance(file_path, str):
+        raise ConfigError("%r is not a string", config_path)
+
+    try:
+        os.stat(file_path)
+        with open(file_path) as file_stream:
+            return file_stream.read()
+    except OSError as e:
+        raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
+
+
+__all__ = [
+    "Config",
+    "RootConfig",
+    "ShardedWorkerHandlingConfig",
+    "RoutableShardedWorkerHandlingConfig",
+    "read_file",
+]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index db16c86f50..e896fd34e2 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig:
 
 class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
     def get_instance(self, key: str) -> str: ...
+
+def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index a27594befc..7f5e449eb2 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 from collections import Counter
-from typing import Iterable, Optional, Tuple, Type
+from typing import Iterable, Mapping, Optional, Tuple, Type
 
 import attr
 
@@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict
 from synapse.util.module_loader import load_module
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
-from ._base import Config, ConfigError
+from ._base import Config, ConfigError, read_file
 
 DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
 
@@ -97,7 +97,26 @@ class OIDCConfig(Config):
         #
         #   client_id: Required. oauth2 client id to use.
         #
-        #   client_secret: Required. oauth2 client secret to use.
+        #   client_secret: oauth2 client secret to use. May be omitted if
+        #        client_secret_jwt_key is given, or if client_auth_method is 'none'.
+        #
+        #   client_secret_jwt_key: Alternative to client_secret: details of a key used
+        #      to create a JSON Web Token to be used as an OAuth2 client secret. If
+        #      given, must be a dictionary with the following properties:
+        #
+        #          key: a pem-encoded signing key. Must be a suitable key for the
+        #              algorithm specified. Required unless 'key_file' is given.
+        #
+        #          key_file: the path to file containing a pem-encoded signing key file.
+        #              Required unless 'key' is given.
+        #
+        #          jwt_header: a dictionary giving properties to include in the JWT
+        #              header. Must include the key 'alg', giving the algorithm used to
+        #              sign the JWT, such as "ES256", using the JWA identifiers in
+        #              RFC7518.
+        #
+        #          jwt_payload: an optional dictionary giving properties to include in
+        #              the JWT payload. Normally this should include an 'iss' key.
         #
         #   client_auth_method: auth method to use when exchanging the token. Valid
         #       values are 'client_secret_basic' (default), 'client_secret_post' and
@@ -240,7 +259,7 @@ class OIDCConfig(Config):
 # jsonschema definition of the configuration settings for an oidc identity provider
 OIDC_PROVIDER_CONFIG_SCHEMA = {
     "type": "object",
-    "required": ["issuer", "client_id", "client_secret"],
+    "required": ["issuer", "client_id"],
     "properties": {
         "idp_id": {
             "type": "string",
@@ -262,6 +281,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
         "issuer": {"type": "string"},
         "client_id": {"type": "string"},
         "client_secret": {"type": "string"},
+        "client_secret_jwt_key": {
+            "type": "object",
+            "required": ["jwt_header"],
+            "oneOf": [
+                {"required": ["key"]},
+                {"required": ["key_file"]},
+            ],
+            "properties": {
+                "key": {"type": "string"},
+                "key_file": {"type": "string"},
+                "jwt_header": {
+                    "type": "object",
+                    "required": ["alg"],
+                    "properties": {
+                        "alg": {"type": "string"},
+                    },
+                    "additionalProperties": {"type": "string"},
+                },
+                "jwt_payload": {
+                    "type": "object",
+                    "additionalProperties": {"type": "string"},
+                },
+            },
+        },
         "client_auth_method": {
             "type": "string",
             # the following list is the same as the keys of
@@ -404,6 +447,20 @@ def _parse_oidc_config_dict(
                 "idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
             ) from e
 
+    client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
+    client_secret_jwt_key = None  # type: Optional[OidcProviderClientSecretJwtKey]
+    if client_secret_jwt_key_config is not None:
+        keyfile = client_secret_jwt_key_config.get("key_file")
+        if keyfile:
+            key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
+        else:
+            key = client_secret_jwt_key_config["key"]
+        client_secret_jwt_key = OidcProviderClientSecretJwtKey(
+            key=key,
+            jwt_header=client_secret_jwt_key_config["jwt_header"],
+            jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
+        )
+
     return OidcProviderConfig(
         idp_id=idp_id,
         idp_name=oidc_config.get("idp_name", "OIDC"),
@@ -412,7 +469,8 @@ def _parse_oidc_config_dict(
         discover=oidc_config.get("discover", True),
         issuer=oidc_config["issuer"],
         client_id=oidc_config["client_id"],
-        client_secret=oidc_config["client_secret"],
+        client_secret=oidc_config.get("client_secret"),
+        client_secret_jwt_key=client_secret_jwt_key,
         client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
         scopes=oidc_config.get("scopes", ["openid"]),
         authorization_endpoint=oidc_config.get("authorization_endpoint"),
@@ -428,6 +486,18 @@ def _parse_oidc_config_dict(
 
 
 @attr.s(slots=True, frozen=True)
+class OidcProviderClientSecretJwtKey:
+    # a pem-encoded signing key
+    key = attr.ib(type=str)
+
+    # properties to include in the JWT header
+    jwt_header = attr.ib(type=Mapping[str, str])
+
+    # properties to include in the JWT payload.
+    jwt_payload = attr.ib(type=Mapping[str, str])
+
+
+@attr.s(slots=True, frozen=True)
 class OidcProviderConfig:
     # a unique identifier for this identity provider. Used in the 'user_external_ids'
     # table, as well as the query/path parameter used in the login protocol.
@@ -452,8 +522,13 @@ class OidcProviderConfig:
     # oauth2 client id to use
     client_id = attr.ib(type=str)
 
-    # oauth2 client secret to use
-    client_secret = attr.ib(type=str)
+    # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
+    # a secret.
+    client_secret = attr.ib(type=Optional[str])
+
+    # key to use to construct a JWT to use as a client secret. May be `None` if
+    # `client_secret` is set.
+    client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
 
     # auth method to use when exchanging the token.
     # Valid values are 'client_secret_basic', 'client_secret_post' and
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index b4a74390cc..825fadb76f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,13 +15,13 @@
 # limitations under the License.
 import inspect
 import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
 from urllib.parse import urlencode
 
 import attr
 import pymacaroons
 from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken
+from authlib.jose import JsonWebToken, jwt
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
 from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
@@ -35,12 +36,15 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.config.oidc_config import OidcProviderConfig
+from synapse.config.oidc_config import (
+    OidcProviderClientSecretJwtKey,
+    OidcProviderConfig,
+)
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
-from synapse.util import json_decoder
+from synapse.util import Clock, json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 
@@ -276,9 +280,21 @@ class OidcProvider:
 
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
+
+        client_secret = None  # type: Union[None, str, JwtClientSecret]
+        if provider.client_secret:
+            client_secret = provider.client_secret
+        elif provider.client_secret_jwt_key:
+            client_secret = JwtClientSecret(
+                provider.client_secret_jwt_key,
+                provider.client_id,
+                provider.issuer,
+                hs.get_clock(),
+            )
+
         self._client_auth = ClientAuth(
             provider.client_id,
-            provider.client_secret,
+            client_secret,
             provider.client_auth_method,
         )  # type: ClientAuth
         self._client_auth_method = provider.client_auth_method
@@ -977,6 +993,81 @@ class OidcProvider:
         return str(remote_user_id)
 
 
+# number of seconds a newly-generated client secret should be valid for
+CLIENT_SECRET_VALIDITY_SECONDS = 3600
+
+# minimum remaining validity on a client secret before we should generate a new one
+CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
+
+
+class JwtClientSecret:
+    """A class which generates a new client secret on demand, based on a JWK
+
+    This implementation is designed to comply with the requirements for Apple Sign in:
+    https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
+
+    It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
+    but it's worth noting that we still put the generated secret in the "client_secret"
+    field (or rather, whereever client_auth_method puts it) rather than in a
+    client_assertion field in the body as that RFC seems to require.
+    """
+
+    def __init__(
+        self,
+        key: OidcProviderClientSecretJwtKey,
+        oauth_client_id: str,
+        oauth_issuer: str,
+        clock: Clock,
+    ):
+        self._key = key
+        self._oauth_client_id = oauth_client_id
+        self._oauth_issuer = oauth_issuer
+        self._clock = clock
+        self._cached_secret = b""
+        self._cached_secret_replacement_time = 0
+
+    def __str__(self):
+        # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
+        # encode_client_secret_basic, which calls "{}".format(secret), which ends up
+        # here.
+        return self._get_secret().decode("ascii")
+
+    def __bytes__(self):
+        # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
+        # encode_client_secret_post, which ends up here.
+        return self._get_secret()
+
+    def _get_secret(self) -> bytes:
+        now = self._clock.time()
+
+        # if we have enough validity on our existing secret, use it
+        if now < self._cached_secret_replacement_time:
+            return self._cached_secret
+
+        issued_at = int(now)
+        expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
+
+        # we copy the configured header because jwt.encode modifies it.
+        header = dict(self._key.jwt_header)
+
+        # see https://tools.ietf.org/html/rfc7523#section-3
+        payload = {
+            "sub": self._oauth_client_id,
+            "aud": self._oauth_issuer,
+            "iat": issued_at,
+            "exp": expires_at,
+            **self._key.jwt_payload,
+        }
+        logger.info(
+            "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
+        )
+        self._cached_secret = jwt.encode(header, payload, self._key.key)
+        self._cached_secret_replacement_time = (
+            expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
+        )
+        return self._cached_secret
+
+
 class OidcSessionTokenGenerator:
     """Methods for generating and checking OIDC Session cookies."""