diff options
Diffstat (limited to 'synapse/handlers/oidc.py')
-rw-r--r-- | synapse/handlers/oidc.py | 395 |
1 files changed, 359 insertions, 36 deletions
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index d7a8226900..03de6a4ba6 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -12,14 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import binascii import inspect +import json import logging -from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Type, + TypeVar, + Union, +) from urllib.parse import urlencode, urlparse import attr +import unpaddedbase64 from authlib.common.security import generate_token -from authlib.jose import JsonWebToken, jwt +from authlib.jose import JsonWebToken, JWTClaims +from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oidc.core import CodeIDToken, UserInfo @@ -35,9 +49,12 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from twisted.web.http_headers import Headers +from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig from synapse.handlers.sso import MappingException, UserAttributes +from synapse.http.server import finish_request +from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart @@ -88,6 +105,8 @@ class Token(TypedDict): #: there is no real point of doing this in our case. JWK = Dict[str, str] +C = TypeVar("C") + #: A JWK Set, as per RFC7517 sec 5. class JWKS(TypedDict): @@ -247,6 +266,80 @@ class OidcHandler: await oidc_provider.handle_oidc_callback(request, session_data, code) + async def handle_backchannel_logout(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + This extracts the logout_token from the request and tries to figure out + which OpenID Provider it is comming from. This works by matching the iss claim + with the issuer and the aud claim with the client_id. + + Since at this point we don't know who signed the JWT, we can't just + decode it using authlib since it will always verifies the signature. We + have to decode it manually without validating the signature. The actual JWT + verification is done in the `OidcProvider.handler_backchannel_logout` method, + once we figured out which provider sent the request. + + Args: + request: the incoming request from the browser. + """ + logout_token = parse_string(request, "logout_token") + if logout_token is None: + raise SynapseError(400, "Missing logout_token in request") + + # A JWT looks like this: + # header.payload.signature + # where all parts are encoded with urlsafe base64. + # The aud and iss claims we care about are in the payload part, which + # is a JSON object. + try: + # By destructuring the list after splitting, we ensure that we have + # exactly 3 segments + _, payload, _ = logout_token.split(".") + except ValueError: + raise SynapseError(400, "Invalid logout_token in request") + + try: + payload_bytes = unpaddedbase64.decode_base64(payload) + claims = json_decoder.decode(payload_bytes.decode("utf-8")) + except (json.JSONDecodeError, binascii.Error, UnicodeError): + raise SynapseError(400, "Invalid logout_token payload in request") + + try: + # Let's extract the iss and aud claims + iss = claims["iss"] + aud = claims["aud"] + # The aud claim can be either a string or a list of string. Here we + # normalize it as a list of strings. + if isinstance(aud, str): + aud = [aud] + + # Check that we have the right types for the aud and the iss claims + if not isinstance(iss, str) or not isinstance(aud, list): + raise TypeError() + for a in aud: + if not isinstance(a, str): + raise TypeError() + + # At this point we properly checked both claims types + issuer: str = iss + audience: List[str] = aud + except (TypeError, KeyError): + raise SynapseError(400, "Invalid issuer/audience in logout_token") + + # Now that we know the audience and the issuer, we can figure out from + # what provider it is coming from + oidc_provider: Optional[OidcProvider] = None + for provider in self._providers.values(): + if provider.issuer == issuer and provider.client_id in audience: + oidc_provider = provider + break + + if oidc_provider is None: + raise SynapseError(400, "Could not find the OP that issued this event") + + # Ask the provider to handle the logout request. + await oidc_provider.handle_backchannel_logout(request, logout_token) + class OidcError(Exception): """Used to catch errors when calling the token_endpoint""" @@ -275,6 +368,7 @@ class OidcProvider: provider: OidcProviderConfig, ): self._store = hs.get_datastores().main + self._clock = hs.get_clock() self._macaroon_generaton = macaroon_generator @@ -341,6 +435,7 @@ class OidcProvider: self.idp_brand = provider.idp_brand self._sso_handler = hs.get_sso_handler() + self._device_handler = hs.get_device_handler() self._sso_handler.register_identity_provider(self) @@ -399,6 +494,41 @@ class OidcProvider: # If we're not using userinfo, we need a valid jwks to validate the ID token m.validate_jwks_uri() + if self._config.backchannel_logout_enabled: + if not m.get("backchannel_logout_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled for issuer %r" + "but it does not advertise support for it", + self.issuer, + ) + + elif not m.get("backchannel_logout_session_supported", False): + logger.warning( + "OIDC Back-Channel Logout is enabled and supported " + "by issuer %r but it might not send a session ID with " + "logout tokens, which is required for the logouts to work", + self.issuer, + ) + + if not self._config.backchannel_logout_ignore_sub: + # If OIDC backchannel logouts are enabled, the provider mapping provider + # should use the `sub` claim. We verify that by mapping a dumb user and + # see if we get back the sub claim + user = UserInfo({"sub": "thisisasubject"}) + try: + subject = self._user_mapping_provider.get_remote_user_id(user) + if subject != user["sub"]: + raise ValueError("Unexpected subject") + except Exception: + logger.warning( + f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} " + "but it looks like the configured `user_mapping_provider` " + "does not use the `sub` claim as subject. If it is the case, " + "and you want Synapse to ignore the `sub` claim in OIDC " + "Back-Channel Logouts, set `backchannel_logout_ignore_sub` " + "to `true` in the issuer config." + ) + @property def _uses_userinfo(self) -> bool: """Returns True if the ``userinfo_endpoint`` should be used. @@ -414,6 +544,16 @@ class OidcProvider: or self._user_profile_method == "userinfo_endpoint" ) + @property + def issuer(self) -> str: + """The issuer identifying this provider.""" + return self._config.issuer + + @property + def client_id(self) -> str: + """The client_id used when interacting with this provider.""" + return self._config.client_id + async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: """Return the provider metadata. @@ -647,7 +787,7 @@ class OidcProvider: Must include an ``access_token`` field. Returns: - UserInfo: an object representing the user. + an object representing the user. """ logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() @@ -661,61 +801,99 @@ class OidcProvider: return UserInfo(resp) - async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: - """Return an instance of UserInfo from token's ``id_token``. + async def _verify_jwt( + self, + alg_values: List[str], + token: str, + claims_cls: Type[C], + claims_options: Optional[dict] = None, + claims_params: Optional[dict] = None, + ) -> C: + """Decode and validate a JWT, re-fetching the JWKS as needed. Args: - token: the token given by the ``token_endpoint``. - Must include an ``id_token`` field. - nonce: the nonce value originally sent in the initial authorization - request. This value should match the one inside the token. + alg_values: list of `alg` values allowed when verifying the JWT. + token: the JWT. + claims_cls: the JWTClaims class to use to validate the claims. + claims_options: dict of options passed to the `claims_cls` constructor. + claims_params: dict of params passed to the `claims_cls` constructor. Returns: - The decoded claims in the ID token. + The decoded claims in the JWT. """ - metadata = await self.load_metadata() - claims_params = { - "nonce": nonce, - "client_id": self._client_auth.client_id, - } - if "access_token" in token: - # If we got an `access_token`, there should be an `at_hash` claim - # in the `id_token` that we can check against. - claims_params["access_token"] = token["access_token"] - - alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) jwt = JsonWebToken(alg_values) - claim_options = {"iss": {"values": [metadata["issuer"]]}} - - id_token = token["id_token"] - logger.debug("Attempting to decode JWT id_token %r", id_token) + logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token) # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() try: claims = jwt.decode( - id_token, + token, key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, + claims_cls=claims_cls, + claims_options=claims_options, claims_params=claims_params, ) except ValueError: logger.info("Reloading JWKS after decode error") jwk_set = await self.load_jwks(force=True) # try reloading the jwks claims = jwt.decode( - id_token, + token, key=jwk_set, - claims_cls=CodeIDToken, - claims_options=claim_options, + claims_cls=claims_cls, + claims_options=claims_options, claims_params=claims_params, ) - logger.debug("Decoded id_token JWT %r; validating", claims) + logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims) - claims.validate(leeway=120) # allows 2 min of clock skew + claims.validate( + now=self._clock.time(), leeway=120 + ) # allows 2 min of clock skew + return claims + + async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken: + """Return an instance of UserInfo from token's ``id_token``. + + Args: + token: the token given by the ``token_endpoint``. + Must include an ``id_token`` field. + nonce: the nonce value originally sent in the initial authorization + request. This value should match the one inside the token. + + Returns: + The decoded claims in the ID token. + """ + id_token = token.get("id_token") + + # That has been theoritically been checked by the caller, so even though + # assertion are not enabled in production, it is mainly here to appease mypy + assert id_token is not None + + metadata = await self.load_metadata() + + claims_params = { + "nonce": nonce, + "client_id": self._client_auth.client_id, + } + if "access_token" in token: + # If we got an `access_token`, there should be an `at_hash` claim + # in the `id_token` that we can check against. + claims_params["access_token"] = token["access_token"] + + claims_options = {"iss": {"values": [metadata["issuer"]]}} + + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + + claims = await self._verify_jwt( + alg_values=alg_values, + token=id_token, + claims_cls=CodeIDToken, + claims_options=claims_options, + claims_params=claims_params, + ) return claims @@ -1036,6 +1214,146 @@ class OidcProvider: # to be strings. return str(remote_user_id) + async def handle_backchannel_logout( + self, request: SynapseRequest, logout_token: str + ) -> None: + """Handle an incoming request to /_synapse/client/oidc/backchannel_logout + + The OIDC Provider posts a logout token to this endpoint when a user + session ends. That token is a JWT signed with the same keys as + ID tokens. The OpenID Connect Back-Channel Logout draft explains how to + validate the JWT and figure out what session to end. + + Args: + request: The request to respond to + logout_token: The logout token (a JWT) extracted from the request body + """ + # Back-Channel Logout can be disabled in the config, hence this check. + # This is not that important for now since Synapse is registered + # manually to the OP, so not specifying the backchannel-logout URI is + # as effective than disabling it here. It might make more sense if we + # support dynamic registration in Synapse at some point. + if not self._config.backchannel_logout_enabled: + logger.warning( + f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config" + ) + + # TODO: this responds with a 400 status code, which is what the OIDC + # Back-Channel Logout spec expects, but spec also suggests answering with + # a JSON object, with the `error` and `error_description` fields set, which + # we are not doing here. + # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse + raise SynapseError( + 400, "OpenID Connect Back-Channel Logout is disabled for this provider" + ) + + metadata = await self.load_metadata() + + # As per OIDC Back-Channel Logout 1.0 sec. 2.4: + # A Logout Token MUST be signed and MAY also be encrypted. The same + # keys are used to sign and encrypt Logout Tokens as are used for ID + # Tokens. If the Logout Token is encrypted, it SHOULD replicate the + # iss (issuer) claim in the JWT Header Parameters, as specified in + # Section 5.3 of [JWT]. + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + + # As per sec. 2.6: + # 3. Validate the iss, aud, and iat Claims in the same way they are + # validated in ID Tokens. + # Which means the audience should contain Synapse's client_id and the + # issuer should be the IdP issuer + claims_options = { + "iss": {"values": [metadata["issuer"]]}, + "aud": {"values": [self.client_id]}, + } + + try: + claims = await self._verify_jwt( + alg_values=alg_values, + token=logout_token, + claims_cls=LogoutToken, + claims_options=claims_options, + ) + except JoseError: + logger.exception("Invalid logout_token") + raise SynapseError(400, "Invalid logout_token") + + # As per sec. 2.6: + # 4. Verify that the Logout Token contains a sub Claim, a sid Claim, + # or both. + # 5. Verify that the Logout Token contains an events Claim whose + # value is JSON object containing the member name + # http://schemas.openid.net/event/backchannel-logout. + # 6. Verify that the Logout Token does not contain a nonce Claim. + # This is all verified by the LogoutToken claims class, so at this + # point the `sid` claim exists and is a string. + sid: str = claims.get("sid") + + # If the `sub` claim was included in the logout token, we check that it matches + # that it matches the right user. We can have cases where the `sub` claim is not + # the ID saved in database, so we let admins disable this check in config. + sub: Optional[str] = claims.get("sub") + expected_user_id: Optional[str] = None + if sub is not None and not self._config.backchannel_logout_ignore_sub: + expected_user_id = await self._store.get_user_by_external_id( + self.idp_id, sub + ) + + # Invalidate any running user-mapping sessions, in-flight login tokens and + # active devices + await self._sso_handler.revoke_sessions_for_provider_session_id( + auth_provider_id=self.idp_id, + auth_provider_session_id=sid, + expected_user_id=expected_user_id, + ) + + request.setResponseCode(200) + request.setHeader(b"Cache-Control", b"no-cache, no-store") + request.setHeader(b"Pragma", b"no-cache") + finish_request(request) + + +class LogoutToken(JWTClaims): + """ + Holds and verify claims of a logout token, as per + https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken + """ + + REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"] + + def validate(self, now: Optional[int] = None, leeway: int = 0) -> None: + """Validate everything in claims payload.""" + super().validate(now, leeway) + self.validate_sid() + self.validate_events() + self.validate_nonce() + + def validate_sid(self) -> None: + """Ensure the sid claim is present""" + sid = self.get("sid") + if not sid: + raise MissingClaimError("sid") + + if not isinstance(sid, str): + raise InvalidClaimError("sid") + + def validate_nonce(self) -> None: + """Ensure the nonce claim is absent""" + if "nonce" in self: + raise InvalidClaimError("nonce") + + def validate_events(self) -> None: + """Ensure the events claim is present and with the right value""" + events = self.get("events") + if not events: + raise MissingClaimError("events") + + if not isinstance(events, dict): + raise InvalidClaimError("events") + + if "http://schemas.openid.net/event/backchannel-logout" not in events: + raise InvalidClaimError("events") + # number of seconds a newly-generated client secret should be valid for CLIENT_SECRET_VALIDITY_SECONDS = 3600 @@ -1105,6 +1423,7 @@ class JwtClientSecret: logger.info( "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload ) + jwt = JsonWebToken(header["alg"]) self._cached_secret = jwt.encode(header, payload, self._key.key) self._cached_secret_replacement_time = ( expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS @@ -1116,12 +1435,10 @@ class UserAttributeDict(TypedDict): localpart: Optional[str] confirm_localpart: bool display_name: Optional[str] + picture: Optional[str] # may be omitted by older `OidcMappingProviders` emails: List[str] -C = TypeVar("C") - - class OidcMappingProvider(Generic[C]): """A mapping provider maps a UserInfo object to user attributes. @@ -1204,6 +1521,7 @@ env.filters.update( @attr.s(slots=True, frozen=True, auto_attribs=True) class JinjaOidcMappingConfig: subject_claim: str + picture_claim: str localpart_template: Optional[Template] display_name_template: Optional[Template] email_template: Optional[Template] @@ -1223,6 +1541,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @staticmethod def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") + picture_claim = config.get("picture_claim", "picture") def parse_template_config(option_name: str) -> Optional[Template]: if option_name not in config: @@ -1256,6 +1575,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): return JinjaOidcMappingConfig( subject_claim=subject_claim, + picture_claim=picture_claim, localpart_template=localpart_template, display_name_template=display_name_template, email_template=email_template, @@ -1295,10 +1615,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): if email: emails.append(email) + picture = userinfo.get("picture") + return UserAttributeDict( localpart=localpart, display_name=display_name, emails=emails, + picture=picture, confirm_localpart=self._config.confirm_localpart, ) |