diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 825fadb76f..6d8551a6d6 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -29,11 +29,13 @@ from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from jinja2 import Environment, Template
from pymacaroons.exceptions import (
MacaroonDeserializationException,
+ MacaroonInitException,
MacaroonInvalidSignatureException,
)
from typing_extensions import TypedDict
from twisted.web.client import readBody
+from twisted.web.http_headers import Headers
from synapse.config import ConfigError
from synapse.config.oidc_config import (
@@ -216,7 +218,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
- except (MacaroonDeserializationException, KeyError) as e:
+ except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e:
logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@@ -329,6 +331,9 @@ class OidcProvider:
# optional brand identifier for this auth provider
self.idp_brand = provider.idp_brand
+ # Optional brand identifier for the unstable API (see MSC2858).
+ self.unstable_idp_brand = provider.unstable_idp_brand
+
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
@@ -538,7 +543,7 @@ class OidcProvider:
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
- headers = {
+ raw_headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
@@ -552,10 +557,10 @@ class OidcProvider:
body = urlencode(args, True)
# Fill the body/headers with credentials
- uri, headers, body = self._client_auth.prepare(
- method="POST", uri=token_endpoint, headers=headers, body=body
+ uri, raw_headers, body = self._client_auth.prepare(
+ method="POST", uri=token_endpoint, headers=raw_headers, body=body
)
- headers = {k: [v] for (k, v) in headers.items()}
+ headers = Headers({k: [v] for (k, v) in raw_headers.items()})
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
|