diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 3665d91513..deb3539751 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -23,7 +23,7 @@ from authlib.common.security import generate_token
from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
-from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.core import CodeIDToken, UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
@@ -117,7 +117,8 @@ class OidcHandler:
for idp_id, p in self._providers.items():
try:
await p.load_metadata()
- await p.load_jwks()
+ if not p._uses_userinfo:
+ await p.load_jwks()
except Exception as e:
raise Exception(
"Error while initialising OIDC provider %r" % (idp_id,)
@@ -498,10 +499,6 @@ class OidcProvider:
return await self._jwks.get()
async def _load_jwks(self) -> JWKS:
- if self._uses_userinfo:
- # We're not using jwt signing, return an empty jwk set
- return {"keys": []}
-
metadata = await self.load_metadata()
# Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ class OidcProvider:
return UserInfo(resp)
- async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+ async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
"""Return an instance of UserInfo from token's ``id_token``.
Args:
@@ -673,7 +670,7 @@ class OidcProvider:
request. This value should match the one inside the token.
Returns:
- An object representing the user.
+ The decoded claims in the ID token.
"""
metadata = await self.load_metadata()
claims_params = {
@@ -684,9 +681,6 @@ class OidcProvider:
# If we got an `access_token`, there should be an `at_hash` claim
# in the `id_token` that we can check against.
claims_params["access_token"] = token["access_token"]
- claims_cls = CodeIDToken
- else:
- claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
jwt = JsonWebToken(alg_values)
@@ -703,7 +697,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
- claims_cls=claims_cls,
+ claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
@@ -713,7 +707,7 @@ class OidcProvider:
claims = jwt.decode(
id_token,
key=jwk_set,
- claims_cls=claims_cls,
+ claims_cls=CodeIDToken,
claims_options=claim_options,
claims_params=claims_params,
)
@@ -721,7 +715,8 @@ class OidcProvider:
logger.debug("Decoded id_token JWT %r; validating", claims)
claims.validate(leeway=120) # allows 2 min of clock skew
- return UserInfo(claims)
+
+ return claims
async def handle_redirect_request(
self,
@@ -837,8 +832,22 @@ class OidcProvider:
logger.debug("Successfully obtained OAuth2 token data: %r", token)
- # Now that we have a token, get the userinfo, either by decoding the
- # `id_token` or by fetching the `userinfo_endpoint`.
+ # If there is an id_token, it should be validated, regardless of the
+ # userinfo endpoint is used or not.
+ if token.get("id_token") is not None:
+ try:
+ id_token = await self._parse_id_token(token, nonce=session_data.nonce)
+ sid = id_token.get("sid")
+ except Exception as e:
+ logger.exception("Invalid id_token")
+ self._sso_handler.render_error(request, "invalid_token", str(e))
+ return
+ else:
+ id_token = None
+ sid = None
+
+ # Now that we have a token, get the userinfo either from the `id_token`
+ # claims or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
try:
userinfo = await self._fetch_userinfo(token)
@@ -846,13 +855,14 @@ class OidcProvider:
logger.exception("Could not fetch userinfo")
self._sso_handler.render_error(request, "fetch_error", str(e))
return
+ elif id_token is not None:
+ userinfo = UserInfo(id_token)
else:
- try:
- userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
- except Exception as e:
- logger.exception("Invalid id_token")
- self._sso_handler.render_error(request, "invalid_token", str(e))
- return
+ logger.error("Missing id_token in token response")
+ self._sso_handler.render_error(
+ request, "invalid_token", "Missing id_token in token response"
+ )
+ return
# first check if we're doing a UIA
if session_data.ui_auth_session_id:
@@ -884,7 +894,7 @@ class OidcProvider:
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
- userinfo, token, request, session_data.client_redirect_url
+ userinfo, token, request, session_data.client_redirect_url, sid
)
except MappingException as e:
logger.exception("Could not map user")
@@ -896,6 +906,7 @@ class OidcProvider:
token: Token,
request: SynapseRequest,
client_redirect_url: str,
+ sid: Optional[str],
) -> None:
"""Given a UserInfo response, complete the login flow
@@ -1008,6 +1019,7 @@ class OidcProvider:
oidc_response_to_user_attributes,
grandfather_existing_users,
extra_attributes,
+ auth_provider_session_id=sid,
)
def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
|