summary refs log tree commit diff
path: root/synapse/handlers/oidc.py
diff options
context:
space:
mode:
authorQuentin Gliech <quenting@element.io>2021-12-06 18:43:06 +0100
committerGitHub <noreply@github.com>2021-12-06 12:43:06 -0500
commita15a893df8428395df7cb95b729431575001c38a (patch)
tree7572abf2fa680c942dc882cc05e9062bb63b55b8 /synapse/handlers/oidc.py
parentAdd admin API to get some information about federation status (#11407) (diff)
downloadsynapse-a15a893df8428395df7cb95b729431575001c38a.tar.xz
Save the OIDC session ID (sid) with the device on login (#11482)
As a step towards allowing back-channel logout for OIDC.
Diffstat (limited to 'synapse/handlers/oidc.py')
-rw-r--r--synapse/handlers/oidc.py58
1 files changed, 35 insertions, 23 deletions
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 3665d91513..deb3539751 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -23,7 +23,7 @@ from authlib.common.security import generate_token
 from authlib.jose import JsonWebToken, jwt
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
-from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.core import CodeIDToken, UserInfo
 from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
 from jinja2 import Environment, Template
 from pymacaroons.exceptions import (
@@ -117,7 +117,8 @@ class OidcHandler:
         for idp_id, p in self._providers.items():
             try:
                 await p.load_metadata()
-                await p.load_jwks()
+                if not p._uses_userinfo:
+                    await p.load_jwks()
             except Exception as e:
                 raise Exception(
                     "Error while initialising OIDC provider %r" % (idp_id,)
@@ -498,10 +499,6 @@ class OidcProvider:
         return await self._jwks.get()
 
     async def _load_jwks(self) -> JWKS:
-        if self._uses_userinfo:
-            # We're not using jwt signing, return an empty jwk set
-            return {"keys": []}
-
         metadata = await self.load_metadata()
 
         # Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ class OidcProvider:
 
         return UserInfo(resp)
 
-    async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+    async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
         """Return an instance of UserInfo from token's ``id_token``.
 
         Args:
@@ -673,7 +670,7 @@ class OidcProvider:
                 request. This value should match the one inside the token.
 
         Returns:
-            An object representing the user.
+            The decoded claims in the ID token.
         """
         metadata = await self.load_metadata()
         claims_params = {
@@ -684,9 +681,6 @@ class OidcProvider:
             # If we got an `access_token`, there should be an `at_hash` claim
             # in the `id_token` that we can check against.
             claims_params["access_token"] = token["access_token"]
-            claims_cls = CodeIDToken
-        else:
-            claims_cls = ImplicitIDToken
 
         alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
         jwt = JsonWebToken(alg_values)
@@ -703,7 +697,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
@@ -713,7 +707,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
@@ -721,7 +715,8 @@ class OidcProvider:
         logger.debug("Decoded id_token JWT %r; validating", claims)
 
         claims.validate(leeway=120)  # allows 2 min of clock skew
-        return UserInfo(claims)
+
+        return claims
 
     async def handle_redirect_request(
         self,
@@ -837,8 +832,22 @@ class OidcProvider:
 
         logger.debug("Successfully obtained OAuth2 token data: %r", token)
 
-        # Now that we have a token, get the userinfo, either by decoding the
-        # `id_token` or by fetching the `userinfo_endpoint`.
+        # If there is an id_token, it should be validated, regardless of the
+        # userinfo endpoint is used or not.
+        if token.get("id_token") is not None:
+            try:
+                id_token = await self._parse_id_token(token, nonce=session_data.nonce)
+                sid = id_token.get("sid")
+            except Exception as e:
+                logger.exception("Invalid id_token")
+                self._sso_handler.render_error(request, "invalid_token", str(e))
+                return
+        else:
+            id_token = None
+            sid = None
+
+        # Now that we have a token, get the userinfo either from the `id_token`
+        # claims or by fetching the `userinfo_endpoint`.
         if self._uses_userinfo:
             try:
                 userinfo = await self._fetch_userinfo(token)
@@ -846,13 +855,14 @@ class OidcProvider:
                 logger.exception("Could not fetch userinfo")
                 self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
+        elif id_token is not None:
+            userinfo = UserInfo(id_token)
         else:
-            try:
-                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
-            except Exception as e:
-                logger.exception("Invalid id_token")
-                self._sso_handler.render_error(request, "invalid_token", str(e))
-                return
+            logger.error("Missing id_token in token response")
+            self._sso_handler.render_error(
+                request, "invalid_token", "Missing id_token in token response"
+            )
+            return
 
         # first check if we're doing a UIA
         if session_data.ui_auth_session_id:
@@ -884,7 +894,7 @@ class OidcProvider:
         # Call the mapper to register/login the user
         try:
             await self._complete_oidc_login(
-                userinfo, token, request, session_data.client_redirect_url
+                userinfo, token, request, session_data.client_redirect_url, sid
             )
         except MappingException as e:
             logger.exception("Could not map user")
@@ -896,6 +906,7 @@ class OidcProvider:
         token: Token,
         request: SynapseRequest,
         client_redirect_url: str,
+        sid: Optional[str],
     ) -> None:
         """Given a UserInfo response, complete the login flow
 
@@ -1008,6 +1019,7 @@ class OidcProvider:
             oidc_response_to_user_attributes,
             grandfather_existing_users,
             extra_attributes,
+            auth_provider_session_id=sid,
         )
 
     def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: