diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 1607e12935..3adc75fa4a 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -102,7 +102,7 @@ class OidcHandler:
) from e
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
- """Handle an incoming request to /_synapse/oidc/callback
+ """Handle an incoming request to /_synapse/client/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
@@ -123,7 +123,6 @@ class OidcHandler:
Args:
request: the incoming request from the browser.
"""
-
# The provider might redirect with an error.
# In that case, just display it as-is.
if b"error" in request.args:
@@ -137,8 +136,12 @@ class OidcHandler:
# either the provider misbehaving or Synapse being misconfigured.
# The only exception of that is "access_denied", where the user
# probably cancelled the login flow. In other cases, log those errors.
- if error != "access_denied":
- logger.error("Error from the OIDC provider: %s %s", error, description)
+ logger.log(
+ logging.INFO if error == "access_denied" else logging.ERROR,
+ "Received OIDC callback with error: %s %s",
+ error,
+ description,
+ )
self._sso_handler.render_error(request, error, description)
return
@@ -149,7 +152,7 @@ class OidcHandler:
# Fetch the session cookie
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
- logger.info("No session cookie found")
+ logger.info("Received OIDC callback, with no session cookie")
self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
@@ -169,7 +172,7 @@ class OidcHandler:
# Check for the state query parameter
if b"state" not in request.args:
- logger.info("State parameter is missing")
+ logger.info("Received OIDC callback, with no state parameter")
self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
@@ -183,14 +186,16 @@ class OidcHandler:
session, state
)
except (MacaroonDeserializationException, ValueError) as e:
- logger.exception("Invalid session")
+ logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
- logger.exception("Could not verify session")
+ logger.exception("Could not verify session for OIDC callback")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
+ logger.info("Received OIDC callback for IdP %s", session_data.idp_id)
+
oidc_provider = self._providers.get(session_data.idp_id)
if not oidc_provider:
logger.error("OIDC session uses unknown IdP %r", oidc_provider)
@@ -274,6 +279,9 @@ class OidcProvider:
# MXC URI for icon for this auth provider
self.idp_icon = provider.idp_icon
+ # optional brand identifier for this auth provider
+ self.idp_brand = provider.idp_brand
+
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
@@ -562,6 +570,7 @@ class OidcProvider:
Returns:
UserInfo: an object representing the user.
"""
+ logger.debug("Using the OAuth2 access_token to request userinfo")
metadata = await self.load_metadata()
resp = await self._http_client.get_json(
@@ -569,6 +578,8 @@ class OidcProvider:
headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
)
+ logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
+
return UserInfo(resp)
async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
@@ -597,17 +608,19 @@ class OidcProvider:
claims_cls = ImplicitIDToken
alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
-
jwt = JsonWebToken(alg_values)
claim_options = {"iss": {"values": [metadata["issuer"]]}}
+ id_token = token["id_token"]
+ logger.debug("Attempting to decode JWT id_token %r", id_token)
+
# Try to decode the keys in cache first, then retry by forcing the keys
# to be reloaded
jwk_set = await self.load_jwks()
try:
claims = jwt.decode(
- token["id_token"],
+ id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
@@ -617,13 +630,15 @@ class OidcProvider:
logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
claims = jwt.decode(
- token["id_token"],
+ id_token,
key=jwk_set,
claims_cls=claims_cls,
claims_options=claim_options,
claims_params=claims_params,
)
+ logger.debug("Decoded id_token JWT %r; validating", claims)
+
claims.validate(leeway=120) # allows 2 min of clock skew
return UserInfo(claims)
@@ -640,7 +655,7 @@ class OidcProvider:
- ``client_id``: the client ID set in ``oidc_config.client_id``
- ``response_type``: ``code``
- - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+ - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
- ``scope``: the list of scopes set in ``oidc_config.scopes``
- ``state``: a random string
- ``nonce``: a random string
@@ -681,7 +696,7 @@ class OidcProvider:
request.addCookie(
SESSION_COOKIE_NAME,
cookie,
- path="/_synapse/oidc",
+ path="/_synapse/client/oidc",
max_age="3600",
httpOnly=True,
sameSite="lax",
@@ -702,7 +717,7 @@ class OidcProvider:
async def handle_oidc_callback(
self, request: SynapseRequest, session_data: "OidcSessionData", code: str
) -> None:
- """Handle an incoming request to /_synapse/oidc/callback
+ """Handle an incoming request to /_synapse/client/oidc/callback
By this time we have already validated the session on the synapse side, and
now need to do the provider-specific operations. This includes:
@@ -723,19 +738,18 @@ class OidcProvider:
"""
# Exchange the code with the provider
try:
- logger.debug("Exchanging code")
+ logger.debug("Exchanging OAuth2 code for a token")
token = await self._exchange_code(code)
except OidcError as e:
- logger.exception("Could not exchange code")
+ logger.exception("Could not exchange OAuth2 code")
self._sso_handler.render_error(request, e.error, e.error_description)
return
- logger.debug("Successfully obtained OAuth2 access token")
+ logger.debug("Successfully obtained OAuth2 token data: %r", token)
# Now that we have a token, get the userinfo, either by decoding the
# `id_token` or by fetching the `userinfo_endpoint`.
if self._uses_userinfo:
- logger.debug("Fetching userinfo")
try:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
@@ -743,7 +757,6 @@ class OidcProvider:
self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
- logger.debug("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
@@ -1056,7 +1069,8 @@ class OidcSessionData:
UserAttributeDict = TypedDict(
- "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
+ "UserAttributeDict",
+ {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
)
C = TypeVar("C")
@@ -1135,11 +1149,12 @@ def jinja_finalize(thing):
env = Environment(finalize=jinja_finalize)
-@attr.s
+@attr.s(slots=True, frozen=True)
class JinjaOidcMappingConfig:
subject_claim = attr.ib(type=str)
localpart_template = attr.ib(type=Optional[Template])
display_name_template = attr.ib(type=Optional[Template])
+ email_template = attr.ib(type=Optional[Template])
extra_attributes = attr.ib(type=Dict[str, Template])
@@ -1156,23 +1171,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub")
- localpart_template = None # type: Optional[Template]
- if "localpart_template" in config:
+ def parse_template_config(option_name: str) -> Optional[Template]:
+ if option_name not in config:
+ return None
try:
- localpart_template = env.from_string(config["localpart_template"])
+ return env.from_string(config[option_name])
except Exception as e:
- raise ConfigError(
- "invalid jinja template", path=["localpart_template"]
- ) from e
+ raise ConfigError("invalid jinja template", path=[option_name]) from e
- display_name_template = None # type: Optional[Template]
- if "display_name_template" in config:
- try:
- display_name_template = env.from_string(config["display_name_template"])
- except Exception as e:
- raise ConfigError(
- "invalid jinja template", path=["display_name_template"]
- ) from e
+ localpart_template = parse_template_config("localpart_template")
+ display_name_template = parse_template_config("display_name_template")
+ email_template = parse_template_config("email_template")
extra_attributes = {} # type Dict[str, Template]
if "extra_attributes" in config:
@@ -1192,6 +1201,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
+ email_template=email_template,
extra_attributes=extra_attributes,
)
@@ -1213,16 +1223,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
# a usable mxid.
localpart += str(failures) if failures else ""
- display_name = None # type: Optional[str]
- if self._config.display_name_template is not None:
- display_name = self._config.display_name_template.render(
- user=userinfo
- ).strip()
+ def render_template_field(template: Optional[Template]) -> Optional[str]:
+ if template is None:
+ return None
+ return template.render(user=userinfo).strip()
+
+ display_name = render_template_field(self._config.display_name_template)
+ if display_name == "":
+ display_name = None
- if display_name == "":
- display_name = None
+ emails = [] # type: List[str]
+ email = render_template_field(self._config.email_template)
+ if email:
+ emails.append(email)
- return UserAttributeDict(localpart=localpart, display_name=display_name)
+ return UserAttributeDict(
+ localpart=localpart, display_name=display_name, emails=emails
+ )
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
|