summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16972.feature1
-rw-r--r--synapse/handlers/oidc.py32
2 files changed, 29 insertions, 4 deletions
diff --git a/changelog.d/16972.feature b/changelog.d/16972.feature
new file mode 100644
index 0000000000..0f28cbbcd6
--- /dev/null
+++ b/changelog.d/16972.feature
@@ -0,0 +1 @@
+OIDC: try to JWT decode userinfo response if JSON parsing failed.
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index ba67cc4768..ab28dc800e 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -829,14 +829,38 @@ class OidcProvider:
         logger.debug("Using the OAuth2 access_token to request userinfo")
         metadata = await self.load_metadata()
 
-        resp = await self._http_client.get_json(
+        resp = await self._http_client.request(
+            "GET",
             metadata["userinfo_endpoint"],
-            headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
+            headers=Headers(
+                {"Authorization": ["Bearer {}".format(token["access_token"])]}
+            ),
         )
 
-        logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
+        body = await readBody(resp)
+
+        content_type_headers = resp.headers.getRawHeaders("Content-Type")
+        assert content_type_headers
+        # We use `startswith` because the header value can contain the `charset` parameter
+        # even if it is useless, and Twisted doesn't take care of that for us.
+        if content_type_headers[0].startswith("application/jwt"):
+            alg_values = metadata.get(
+                "id_token_signing_alg_values_supported", ["RS256"]
+            )
+            jwt = JsonWebToken(alg_values)
+            jwk_set = await self.load_jwks()
+            try:
+                decoded_resp = jwt.decode(body, key=jwk_set)
+            except ValueError:
+                logger.info("Reloading JWKS after decode error")
+                jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
+                decoded_resp = jwt.decode(body, key=jwk_set)
+        else:
+            decoded_resp = json_decoder.decode(body.decode("utf-8"))
+
+        logger.debug("Retrieved user info from userinfo endpoint: %r", decoded_resp)
 
-        return UserInfo(resp)
+        return UserInfo(decoded_resp)
 
     async def _verify_jwt(
         self,