summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/oidc.py54
1 files changed, 47 insertions, 7 deletions
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 24e1cec5b6..0fc829acf7 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -36,6 +36,7 @@ from authlib.jose import JsonWebToken, JWTClaims
 from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
+from authlib.oauth2.rfc7636.challenge import create_s256_code_challenge
 from authlib.oidc.core import CodeIDToken, UserInfo
 from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
 from jinja2 import Environment, Template
@@ -475,6 +476,16 @@ class OidcProvider:
                     )
                 )
 
+        # If PKCE support is advertised ensure the wanted method is available.
+        if m.get("code_challenge_methods_supported") is not None:
+            m.validate_code_challenge_methods_supported()
+            if "S256" not in m["code_challenge_methods_supported"]:
+                raise ValueError(
+                    '"S256" not in "code_challenge_methods_supported" ({supported!r})'.format(
+                        supported=m["code_challenge_methods_supported"],
+                    )
+                )
+
         if m.get("response_types_supported") is not None:
             m.validate_response_types_supported()
 
@@ -602,6 +613,11 @@ class OidcProvider:
         if self._config.jwks_uri:
             metadata["jwks_uri"] = self._config.jwks_uri
 
+        if self._config.pkce_method == "always":
+            metadata["code_challenge_methods_supported"] = ["S256"]
+        elif self._config.pkce_method == "never":
+            metadata.pop("code_challenge_methods_supported", None)
+
         self._validate_metadata(metadata)
 
         return metadata
@@ -653,7 +669,7 @@ class OidcProvider:
 
         return jwk_set
 
-    async def _exchange_code(self, code: str) -> Token:
+    async def _exchange_code(self, code: str, code_verifier: str) -> Token:
         """Exchange an authorization code for a token.
 
         This calls the ``token_endpoint`` with the authorization code we
@@ -666,6 +682,7 @@ class OidcProvider:
 
         Args:
             code: The authorization code we got from the callback.
+            code_verifier: The PKCE code verifier to send, blank if unused.
 
         Returns:
             A dict containing various tokens.
@@ -696,6 +713,8 @@ class OidcProvider:
             "code": code,
             "redirect_uri": self._callback_url,
         }
+        if code_verifier:
+            args["code_verifier"] = code_verifier
         body = urlencode(args, True)
 
         # Fill the body/headers with credentials
@@ -914,11 +933,14 @@ class OidcProvider:
           - ``scope``: the list of scopes set in ``oidc_config.scopes``
           - ``state``: a random string
           - ``nonce``: a random string
+          - ``code_challenge``: a RFC7636 code challenge (if PKCE is supported)
 
-        In addition generating a redirect URL, we are setting a cookie with
-        a signed macaroon token containing the state, the nonce and the
-        client_redirect_url params. Those are then checked when the client
-        comes back from the provider.
+        In addition to generating a redirect URL, we are setting a cookie with
+        a signed macaroon token containing the state, the nonce, the
+        client_redirect_url, and (optionally) the code_verifier params. The state,
+        nonce, and client_redirect_url are then checked when the client comes back
+        from the provider. The code_verifier is passed back to the server during
+        the token exchange and compared to the code_challenge sent in this request.
 
         Args:
             request: the incoming request from the browser.
@@ -935,10 +957,25 @@ class OidcProvider:
 
         state = generate_token()
         nonce = generate_token()
+        code_verifier = ""
 
         if not client_redirect_url:
             client_redirect_url = b""
 
+        metadata = await self.load_metadata()
+
+        # Automatically enable PKCE if it is supported.
+        extra_grant_values = {}
+        if metadata.get("code_challenge_methods_supported"):
+            code_verifier = generate_token(48)
+
+            # Note that we verified the server supports S256 earlier (in
+            # OidcProvider._validate_metadata).
+            extra_grant_values = {
+                "code_challenge_method": "S256",
+                "code_challenge": create_s256_code_challenge(code_verifier),
+            }
+
         cookie = self._macaroon_generaton.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
@@ -946,6 +983,7 @@ class OidcProvider:
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
                 ui_auth_session_id=ui_auth_session_id or "",
+                code_verifier=code_verifier,
             ),
         )
 
@@ -966,7 +1004,6 @@ class OidcProvider:
                 )
             )
 
-        metadata = await self.load_metadata()
         authorization_endpoint = metadata.get("authorization_endpoint")
         return prepare_grant_uri(
             authorization_endpoint,
@@ -976,6 +1013,7 @@ class OidcProvider:
             scope=self._scopes,
             state=state,
             nonce=nonce,
+            **extra_grant_values,
         )
 
     async def handle_oidc_callback(
@@ -1003,7 +1041,9 @@ class OidcProvider:
         # Exchange the code with the provider
         try:
             logger.debug("Exchanging OAuth2 code for a token")
-            token = await self._exchange_code(code)
+            token = await self._exchange_code(
+                code, code_verifier=session_data.code_verifier
+            )
         except OidcError as e:
             logger.warning("Could not exchange OAuth2 code: %s", e)
             self._sso_handler.render_error(request, e.error, e.error_description)