OIDC: try to JWT decode userinfo response if JSON parsing failed (#16972)
2 files changed, 29 insertions, 4 deletions
diff --git a/changelog.d/16972.feature b/changelog.d/16972.feature
new file mode 100644
index 0000000000..0f28cbbcd6
--- /dev/null
+++ b/changelog.d/16972.feature
@@ -0,0 +1 @@
+OIDC: try to JWT decode userinfo response if JSON parsing failed.
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index ba67cc4768..ab28dc800e 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -829,14 +829,38 @@ class OidcProvider:
logger.debug("Using the OAuth2 access_token to request userinfo")
metadata = await self.load_metadata()
- resp = await self._http_client.get_json(
+ resp = await self._http_client.request(
+ "GET",
metadata["userinfo_endpoint"],
- headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
+ headers=Headers(
+ {"Authorization": ["Bearer {}".format(token["access_token"])]}
+ ),
)
- logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
+ body = await readBody(resp)
+
+ content_type_headers = resp.headers.getRawHeaders("Content-Type")
+ assert content_type_headers
+ # We use `startswith` because the header value can contain the `charset` parameter
+ # even if it is useless, and Twisted doesn't take care of that for us.
+ if content_type_headers[0].startswith("application/jwt"):
+ alg_values = metadata.get(
+ "id_token_signing_alg_values_supported", ["RS256"]
+ )
+ jwt = JsonWebToken(alg_values)
+ jwk_set = await self.load_jwks()
+ try:
+ decoded_resp = jwt.decode(body, key=jwk_set)
+ except ValueError:
+ logger.info("Reloading JWKS after decode error")
+ jwk_set = await self.load_jwks(force=True) # try reloading the jwks
+ decoded_resp = jwt.decode(body, key=jwk_set)
+ else:
+ decoded_resp = json_decoder.decode(body.decode("utf-8"))
+
+ logger.debug("Retrieved user info from userinfo endpoint: %r", decoded_resp)
- return UserInfo(resp)
+ return UserInfo(decoded_resp)
async def _verify_jwt(
self,
|