diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 03de6a4ba6..0fc829acf7 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -36,6 +36,7 @@ from authlib.jose import JsonWebToken, JWTClaims
from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
+from authlib.oauth2.rfc7636.challenge import create_s256_code_challenge
from authlib.oidc.core import CodeIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
@@ -475,6 +476,16 @@ class OidcProvider:
)
)
+ # If PKCE support is advertised ensure the wanted method is available.
+ if m.get("code_challenge_methods_supported") is not None:
+ m.validate_code_challenge_methods_supported()
+ if "S256" not in m["code_challenge_methods_supported"]:
+ raise ValueError(
+ '"S256" not in "code_challenge_methods_supported" ({supported!r})'.format(
+ supported=m["code_challenge_methods_supported"],
+ )
+ )
+
if m.get("response_types_supported") is not None:
m.validate_response_types_supported()
@@ -602,6 +613,11 @@ class OidcProvider:
if self._config.jwks_uri:
metadata["jwks_uri"] = self._config.jwks_uri
+ if self._config.pkce_method == "always":
+ metadata["code_challenge_methods_supported"] = ["S256"]
+ elif self._config.pkce_method == "never":
+ metadata.pop("code_challenge_methods_supported", None)
+
self._validate_metadata(metadata)
return metadata
@@ -653,7 +669,7 @@ class OidcProvider:
return jwk_set
- async def _exchange_code(self, code: str) -> Token:
+ async def _exchange_code(self, code: str, code_verifier: str) -> Token:
"""Exchange an authorization code for a token.
This calls the ``token_endpoint`` with the authorization code we
@@ -666,6 +682,7 @@ class OidcProvider:
Args:
code: The authorization code we got from the callback.
+ code_verifier: The PKCE code verifier to send, blank if unused.
Returns:
A dict containing various tokens.
@@ -696,6 +713,8 @@ class OidcProvider:
"code": code,
"redirect_uri": self._callback_url,
}
+ if code_verifier:
+ args["code_verifier"] = code_verifier
body = urlencode(args, True)
# Fill the body/headers with credentials
@@ -914,11 +933,14 @@ class OidcProvider:
- ``scope``: the list of scopes set in ``oidc_config.scopes``
- ``state``: a random string
- ``nonce``: a random string
+ - ``code_challenge``: a RFC7636 code challenge (if PKCE is supported)
- In addition generating a redirect URL, we are setting a cookie with
- a signed macaroon token containing the state, the nonce and the
- client_redirect_url params. Those are then checked when the client
- comes back from the provider.
+ In addition to generating a redirect URL, we are setting a cookie with
+ a signed macaroon token containing the state, the nonce, the
+ client_redirect_url, and (optionally) the code_verifier params. The state,
+ nonce, and client_redirect_url are then checked when the client comes back
+ from the provider. The code_verifier is passed back to the server during
+ the token exchange and compared to the code_challenge sent in this request.
Args:
request: the incoming request from the browser.
@@ -935,10 +957,25 @@ class OidcProvider:
state = generate_token()
nonce = generate_token()
+ code_verifier = ""
if not client_redirect_url:
client_redirect_url = b""
+ metadata = await self.load_metadata()
+
+ # Automatically enable PKCE if it is supported.
+ extra_grant_values = {}
+ if metadata.get("code_challenge_methods_supported"):
+ code_verifier = generate_token(48)
+
+ # Note that we verified the server supports S256 earlier (in
+ # OidcProvider._validate_metadata).
+ extra_grant_values = {
+ "code_challenge_method": "S256",
+ "code_challenge": create_s256_code_challenge(code_verifier),
+ }
+
cookie = self._macaroon_generaton.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
@@ -946,6 +983,7 @@ class OidcProvider:
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id or "",
+ code_verifier=code_verifier,
),
)
@@ -966,7 +1004,6 @@ class OidcProvider:
)
)
- metadata = await self.load_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
return prepare_grant_uri(
authorization_endpoint,
@@ -976,6 +1013,7 @@ class OidcProvider:
scope=self._scopes,
state=state,
nonce=nonce,
+ **extra_grant_values,
)
async def handle_oidc_callback(
@@ -1003,7 +1041,9 @@ class OidcProvider:
# Exchange the code with the provider
try:
logger.debug("Exchanging OAuth2 code for a token")
- token = await self._exchange_code(code)
+ token = await self._exchange_code(
+ code, code_verifier=session_data.code_verifier
+ )
except OidcError as e:
logger.warning("Could not exchange OAuth2 code: %s", e)
self._sso_handler.render_error(request, e.error, e.error_description)
@@ -1520,8 +1560,8 @@ env.filters.update(
@attr.s(slots=True, frozen=True, auto_attribs=True)
class JinjaOidcMappingConfig:
- subject_claim: str
- picture_claim: str
+ subject_template: Template
+ picture_template: Template
localpart_template: Optional[Template]
display_name_template: Optional[Template]
email_template: Optional[Template]
@@ -1540,8 +1580,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@staticmethod
def parse_config(config: dict) -> JinjaOidcMappingConfig:
- subject_claim = config.get("subject_claim", "sub")
- picture_claim = config.get("picture_claim", "picture")
+ def parse_template_config_with_claim(
+ option_name: str, default_claim: str
+ ) -> Template:
+ template_name = f"{option_name}_template"
+ template = config.get(template_name)
+ if not template:
+ # Convert the legacy subject_claim into a template.
+ claim = config.get(f"{option_name}_claim", default_claim)
+ template = "{{ user.%s }}" % (claim,)
+
+ try:
+ return env.from_string(template)
+ except Exception as e:
+ raise ConfigError("invalid jinja template", path=[template_name]) from e
+
+ subject_template = parse_template_config_with_claim("subject", "sub")
+ picture_template = parse_template_config_with_claim("picture", "picture")
def parse_template_config(option_name: str) -> Optional[Template]:
if option_name not in config:
@@ -1574,8 +1629,8 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
raise ConfigError("must be a bool", path=["confirm_localpart"])
return JinjaOidcMappingConfig(
- subject_claim=subject_claim,
- picture_claim=picture_claim,
+ subject_template=subject_template,
+ picture_template=picture_template,
localpart_template=localpart_template,
display_name_template=display_name_template,
email_template=email_template,
@@ -1584,7 +1639,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
)
def get_remote_user_id(self, userinfo: UserInfo) -> str:
- return userinfo[self._config.subject_claim]
+ return self._config.subject_template.render(user=userinfo).strip()
async def map_user_attributes(
self, userinfo: UserInfo, token: Token, failures: int
@@ -1615,7 +1670,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if email:
emails.append(email)
- picture = userinfo.get("picture")
+ picture = self._config.picture_template.render(user=userinfo).strip()
return UserAttributeDict(
localpart=localpart,
|