diff options
Diffstat (limited to 'synapse/handlers/oidc.py')
-rw-r--r-- | synapse/handlers/oidc.py | 1384 |
1 files changed, 1384 insertions, 0 deletions
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py new file mode 100644 index 0000000000..45514be50f --- /dev/null +++ b/synapse/handlers/oidc.py @@ -0,0 +1,1384 @@ +# Copyright 2020 Quentin Gliech +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import logging +from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union +from urllib.parse import urlencode + +import attr +import pymacaroons +from authlib.common.security import generate_token +from authlib.jose import JsonWebToken, jwt +from authlib.oauth2.auth import ClientAuth +from authlib.oauth2.rfc6749.parameters import prepare_grant_uri +from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo +from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url +from jinja2 import Environment, Template +from pymacaroons.exceptions import ( + MacaroonDeserializationException, + MacaroonInitException, + MacaroonInvalidSignatureException, +) +from typing_extensions import TypedDict + +from twisted.web.client import readBody +from twisted.web.http_headers import Headers + +from synapse.config import ConfigError +from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig +from synapse.handlers.sso import MappingException, UserAttributes +from synapse.http.site import SynapseRequest +from synapse.logging.context import make_deferred_yieldable +from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart +from synapse.util import Clock, json_decoder +from synapse.util.caches.cached_call import RetryOnExceptionCachedCall +from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# we want the cookie to be returned to us even when the request is the POSTed +# result of a form on another domain, as is used with `response_mode=form_post`. +# +# Modern browsers will not do so unless we set SameSite=None; however *older* +# browsers (including all versions of Safari on iOS 12?) don't support +# SameSite=None, and interpret it as SameSite=Strict: +# https://bugs.webkit.org/show_bug.cgi?id=198181 +# +# As a rather painful workaround, we set *two* cookies, one with SameSite=None +# and one with no SameSite, in the hope that at least one of them will get +# back to us. +# +# Secure is necessary for SameSite=None (and, empirically, also breaks things +# on iOS 12.) +# +# Here we have the names of the cookies, and the options we use to set them. +_SESSION_COOKIES = [ + (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"), + (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"), +] + +#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and +#: OpenID.Core sec 3.1.3.3. +Token = TypedDict( + "Token", + { + "access_token": str, + "token_type": str, + "id_token": Optional[str], + "refresh_token": Optional[str], + "expires_in": int, + "scope": Optional[str], + }, +) + +#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but +#: there is no real point of doing this in our case. +JWK = Dict[str, str] + +#: A JWK Set, as per RFC7517 sec 5. +JWKS = TypedDict("JWKS", {"keys": List[JWK]}) + + +class OidcHandler: + """Handles requests related to the OpenID Connect login flow.""" + + def __init__(self, hs: "HomeServer"): + self._sso_handler = hs.get_sso_handler() + + provider_confs = hs.config.oidc.oidc_providers + # we should not have been instantiated if there is no configured provider. + assert provider_confs + + self._token_generator = OidcSessionTokenGenerator(hs) + self._providers = { + p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs + } # type: Dict[str, OidcProvider] + + async def load_metadata(self) -> None: + """Validate the config and load the metadata from the remote endpoint. + + Called at startup to ensure we have everything we need. + """ + for idp_id, p in self._providers.items(): + try: + await p.load_metadata() + await p.load_jwks() + except Exception as e: + raise Exception( + "Error while initialising OIDC provider %r" % (idp_id,) + ) from e + + async def handle_oidc_callback(self, request: SynapseRequest) -> None: + """Handle an incoming request to /_synapse/client/oidc/callback + + Since we might want to display OIDC-related errors in a user-friendly + way, we don't raise SynapseError from here. Instead, we call + ``self._sso_handler.render_error`` which displays an HTML page for the error. + + Most of the OpenID Connect logic happens here: + + - first, we check if there was any error returned by the provider and + display it + - then we fetch the session cookie, decode and verify it + - the ``state`` query parameter should match with the one stored in the + session cookie + + Once we know the session is legit, we then delegate to the OIDC Provider + implementation, which will exchange the code with the provider and complete the + login/authentication. + + Args: + request: the incoming request from the browser. + """ + # This will always be set by the time Twisted calls us. + assert request.args is not None + + # The provider might redirect with an error. + # In that case, just display it as-is. + if b"error" in request.args: + # error response from the auth server. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2.1 + # https://openid.net/specs/openid-connect-core-1_0.html#AuthError + error = request.args[b"error"][0].decode() + description = request.args.get(b"error_description", [b""])[0].decode() + + # Most of the errors returned by the provider could be due by + # either the provider misbehaving or Synapse being misconfigured. + # The only exception of that is "access_denied", where the user + # probably cancelled the login flow. In other cases, log those errors. + logger.log( + logging.INFO if error == "access_denied" else logging.ERROR, + "Received OIDC callback with error: %s %s", + error, + description, + ) + + self._sso_handler.render_error(request, error, description) + return + + # otherwise, it is presumably a successful response. see: + # https://tools.ietf.org/html/rfc6749#section-4.1.2 + + # Fetch the session cookie. See the comments on SESSION_COOKIES for why there + # are two. + + for cookie_name, _ in _SESSION_COOKIES: + session = request.getCookie(cookie_name) # type: Optional[bytes] + if session is not None: + break + else: + logger.info("Received OIDC callback, with no session cookie") + self._sso_handler.render_error( + request, "missing_session", "No session cookie found" + ) + return + + # Remove the cookies. There is a good chance that if the callback failed + # once, it will fail next time and the code will already be exchanged. + # Removing the cookies early avoids spamming the provider with token requests. + # + # we have to build the header by hand rather than calling request.addCookie + # because the latter does not support SameSite=None + # (https://twistedmatrix.com/trac/ticket/10088) + + for cookie_name, options in _SESSION_COOKIES: + request.cookies.append( + b"%s=; Expires=Thu, Jan 01 1970 00:00:00 UTC; %s" + % (cookie_name, options) + ) + + # Check for the state query parameter + if b"state" not in request.args: + logger.info("Received OIDC callback, with no state parameter") + self._sso_handler.render_error( + request, "invalid_request", "State parameter is missing" + ) + return + + state = request.args[b"state"][0].decode() + + # Deserialize the session token and verify it. + try: + session_data = self._token_generator.verify_oidc_session_token( + session, state + ) + except (MacaroonInitException, MacaroonDeserializationException, KeyError) as e: + logger.exception("Invalid session for OIDC callback") + self._sso_handler.render_error(request, "invalid_session", str(e)) + return + except MacaroonInvalidSignatureException as e: + logger.exception("Could not verify session for OIDC callback") + self._sso_handler.render_error(request, "mismatching_session", str(e)) + return + + logger.info("Received OIDC callback for IdP %s", session_data.idp_id) + + oidc_provider = self._providers.get(session_data.idp_id) + if not oidc_provider: + logger.error("OIDC session uses unknown IdP %r", oidc_provider) + self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP") + return + + if b"code" not in request.args: + logger.info("Code parameter is missing") + self._sso_handler.render_error( + request, "invalid_request", "Code parameter is missing" + ) + return + + code = request.args[b"code"][0].decode() + + await oidc_provider.handle_oidc_callback(request, session_data, code) + + +class OidcError(Exception): + """Used to catch errors when calling the token_endpoint""" + + def __init__(self, error, error_description=None): + self.error = error + self.error_description = error_description + + def __str__(self): + if self.error_description: + return "{}: {}".format(self.error, self.error_description) + return self.error + + +class OidcProvider: + """Wraps the config for a single OIDC IdentityProvider + + Provides methods for handling redirect requests and callbacks via that particular + IdP. + """ + + def __init__( + self, + hs: "HomeServer", + token_generator: "OidcSessionTokenGenerator", + provider: OidcProviderConfig, + ): + self._store = hs.get_datastore() + + self._token_generator = token_generator + + self._config = provider + self._callback_url = hs.config.oidc_callback_url # type: str + + self._oidc_attribute_requirements = provider.attribute_requirements + self._scopes = provider.scopes + self._user_profile_method = provider.user_profile_method + + client_secret = None # type: Union[None, str, JwtClientSecret] + if provider.client_secret: + client_secret = provider.client_secret + elif provider.client_secret_jwt_key: + client_secret = JwtClientSecret( + provider.client_secret_jwt_key, + provider.client_id, + provider.issuer, + hs.get_clock(), + ) + + self._client_auth = ClientAuth( + provider.client_id, + client_secret, + provider.client_auth_method, + ) # type: ClientAuth + self._client_auth_method = provider.client_auth_method + + # cache of metadata for the identity provider (endpoint uris, mostly). This is + # loaded on-demand from the discovery endpoint (if discovery is enabled), with + # possible overrides from the config. Access via `load_metadata`. + self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata) + + # cache of JWKs used by the identity provider to sign tokens. Loaded on demand + # from the IdP's jwks_uri, if required. + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + + self._user_mapping_provider = provider.user_mapping_provider_class( + provider.user_mapping_provider_config + ) + self._skip_verification = provider.skip_verification + self._allow_existing_users = provider.allow_existing_users + + self._http_client = hs.get_proxied_http_client() + self._server_name = hs.config.server_name # type: str + + # identifier for the external_ids table + self.idp_id = provider.idp_id + + # user-facing name of this auth provider + self.idp_name = provider.idp_name + + # MXC URI for icon for this auth provider + self.idp_icon = provider.idp_icon + + # optional brand identifier for this auth provider + self.idp_brand = provider.idp_brand + + # Optional brand identifier for the unstable API (see MSC2858). + self.unstable_idp_brand = provider.unstable_idp_brand + + self._sso_handler = hs.get_sso_handler() + + self._sso_handler.register_identity_provider(self) + + def _validate_metadata(self, m: OpenIDProviderMetadata) -> None: + """Verifies the provider metadata. + + This checks the validity of the currently loaded provider. Not + everything is checked, only: + + - ``issuer`` + - ``authorization_endpoint`` + - ``token_endpoint`` + - ``response_types_supported`` (checks if "code" is in it) + - ``jwks_uri`` + + Raises: + ValueError: if something in the provider is not valid + """ + # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin) + if self._skip_verification is True: + return + + m.validate_issuer() + m.validate_authorization_endpoint() + m.validate_token_endpoint() + + if m.get("token_endpoint_auth_methods_supported") is not None: + m.validate_token_endpoint_auth_methods_supported() + if ( + self._client_auth_method + not in m["token_endpoint_auth_methods_supported"] + ): + raise ValueError( + '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format( + auth_method=self._client_auth_method, + supported=m["token_endpoint_auth_methods_supported"], + ) + ) + + if m.get("response_types_supported") is not None: + m.validate_response_types_supported() + + if "code" not in m["response_types_supported"]: + raise ValueError( + '"code" not in "response_types_supported" (%r)' + % (m["response_types_supported"],) + ) + + # Ensure there's a userinfo endpoint to fetch from if it is required. + if self._uses_userinfo: + if m.get("userinfo_endpoint") is None: + raise ValueError( + 'provider has no "userinfo_endpoint", even though it is required' + ) + else: + # If we're not using userinfo, we need a valid jwks to validate the ID token + m.validate_jwks_uri() + + @property + def _uses_userinfo(self) -> bool: + """Returns True if the ``userinfo_endpoint`` should be used. + + This is based on the requested scopes: if the scopes include + ``openid``, the provider should give use an ID token containing the + user information. If not, we should fetch them using the + ``access_token`` with the ``userinfo_endpoint``. + """ + + return ( + "openid" not in self._scopes + or self._user_profile_method == "userinfo_endpoint" + ) + + async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: + """Return the provider metadata. + + If this is the first call, the metadata is built from the config and from the + metadata discovery endpoint (if enabled), and then validated. If the metadata + is successfully validated, it is then cached for future use. + + Args: + force: If true, any cached metadata is discarded to force a reload. + + Raises: + ValueError: if something in the provider is not valid + + Returns: + The provider's metadata. + """ + if force: + # reset the cached call to ensure we get a new result + self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata) + + return await self._provider_metadata.get() + + async def _load_metadata(self) -> OpenIDProviderMetadata: + # start out with just the issuer (unlike the other settings, discovered issuer + # takes precedence over configured issuer, because configured issuer is + # required for discovery to take place.) + # + metadata = OpenIDProviderMetadata(issuer=self._config.issuer) + + # load any data from the discovery endpoint, if enabled + if self._config.discover: + url = get_well_known_url(self._config.issuer, external=True) + metadata_response = await self._http_client.get_json(url) + metadata.update(metadata_response) + + # override any discovered data with any settings in our config + if self._config.authorization_endpoint: + metadata["authorization_endpoint"] = self._config.authorization_endpoint + + if self._config.token_endpoint: + metadata["token_endpoint"] = self._config.token_endpoint + + if self._config.userinfo_endpoint: + metadata["userinfo_endpoint"] = self._config.userinfo_endpoint + + if self._config.jwks_uri: + metadata["jwks_uri"] = self._config.jwks_uri + + self._validate_metadata(metadata) + + return metadata + + async def load_jwks(self, force: bool = False) -> JWKS: + """Load the JSON Web Key Set used to sign ID tokens. + + If we're not using the ``userinfo_endpoint``, user infos are extracted + from the ID token, which is a JWT signed by keys given by the provider. + The keys are then cached. + + Args: + force: Force reloading the keys. + + Returns: + The key set + + Looks like this:: + + { + 'keys': [ + { + 'kid': 'abcdef', + 'kty': 'RSA', + 'alg': 'RS256', + 'use': 'sig', + 'e': 'XXXX', + 'n': 'XXXX', + } + ] + } + """ + if force: + # reset the cached call to ensure we get a new result + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + return await self._jwks.get() + + async def _load_jwks(self) -> JWKS: + if self._uses_userinfo: + # We're not using jwt signing, return an empty jwk set + return {"keys": []} + + metadata = await self.load_metadata() + + # Load the JWKS using the `jwks_uri` metadata. + uri = metadata.get("jwks_uri") + if not uri: + # this should be unreachable: load_metadata validates that + # there is a jwks_uri in the metadata if _uses_userinfo is unset + raise RuntimeError('Missing "jwks_uri" in metadata') + + jwk_set = await self._http_client.get_json(uri) + + return jwk_set + + async def _exchange_code(self, code: str) -> Token: + """Exchange an authorization code for a token. + + This calls the ``token_endpoint`` with the authorization code we + received in the callback to exchange it for a token. The call uses the + ``ClientAuth`` to authenticate with the client with its ID and secret. + + See: + https://tools.ietf.org/html/rfc6749#section-3.2 + https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + + Args: + code: The authorization code we got from the callback. + + Returns: + A dict containing various tokens. + + May look like this:: + + { + 'token_type': 'bearer', + 'access_token': 'abcdef', + 'expires_in': 3599, + 'id_token': 'ghijkl', + 'refresh_token': 'mnopqr', + } + + Raises: + OidcError: when the ``token_endpoint`` returned an error. + """ + metadata = await self.load_metadata() + token_endpoint = metadata.get("token_endpoint") + raw_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": self._http_client.user_agent, + "Accept": "application/json", + } + + args = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self._callback_url, + } + body = urlencode(args, True) + + # Fill the body/headers with credentials + uri, raw_headers, body = self._client_auth.prepare( + method="POST", uri=token_endpoint, headers=raw_headers, body=body + ) + headers = Headers({k: [v] for (k, v) in raw_headers.items()}) + + # Do the actual request + # We're not using the SimpleHttpClient util methods as we don't want to + # check the HTTP status code and we do the body encoding ourself. + response = await self._http_client.request( + method="POST", + uri=uri, + data=body.encode("utf-8"), + headers=headers, + ) + + # This is used in multiple error messages below + status = "{code} {phrase}".format( + code=response.code, phrase=response.phrase.decode("utf-8") + ) + + resp_body = await make_deferred_yieldable(readBody(response)) + + if response.code >= 500: + # In case of a server error, we should first try to decode the body + # and check for an error field. If not, we respond with a generic + # error message. + try: + resp = json_decoder.decode(resp_body.decode("utf-8")) + error = resp["error"] + description = resp.get("error_description", error) + except (ValueError, KeyError): + # Catch ValueError for the JSON decoding and KeyError for the "error" field + error = "server_error" + description = ( + ( + 'Authorization server responded with a "{status}" error ' + "while exchanging the authorization code." + ).format(status=status), + ) + + raise OidcError(error, description) + + # Since it is a not a 5xx code, body should be a valid JSON. It will + # raise if not. + resp = json_decoder.decode(resp_body.decode("utf-8")) + + if "error" in resp: + error = resp["error"] + # In case the authorization server responded with an error field, + # it should be a 4xx code. If not, warn about it but don't do + # anything special and report the original error message. + if response.code < 400: + logger.debug( + "Invalid response from the authorization server: " + 'responded with a "{status}" ' + "but body has an error field: {error!r}".format( + status=status, error=resp["error"] + ) + ) + + description = resp.get("error_description", error) + raise OidcError(error, description) + + # Now, this should not be an error. According to RFC6749 sec 5.1, it + # should be a 200 code. We're a bit more flexible than that, and will + # only throw on a 4xx code. + if response.code >= 400: + description = ( + 'Authorization server responded with a "{status}" error ' + 'but did not include an "error" field in its response.'.format( + status=status + ) + ) + logger.warning(description) + # Body was still valid JSON. Might be useful to log it for debugging. + logger.warning("Code exchange response: {resp!r}".format(resp=resp)) + raise OidcError("server_error", description) + + return resp + + async def _fetch_userinfo(self, token: Token) -> UserInfo: + """Fetch user information from the ``userinfo_endpoint``. + + Args: + token: the token given by the ``token_endpoint``. + Must include an ``access_token`` field. + + Returns: + UserInfo: an object representing the user. + """ + logger.debug("Using the OAuth2 access_token to request userinfo") + metadata = await self.load_metadata() + + resp = await self._http_client.get_json( + metadata["userinfo_endpoint"], + headers={"Authorization": ["Bearer {}".format(token["access_token"])]}, + ) + + logger.debug("Retrieved user info from userinfo endpoint: %r", resp) + + return UserInfo(resp) + + async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: + """Return an instance of UserInfo from token's ``id_token``. + + Args: + token: the token given by the ``token_endpoint``. + Must include an ``id_token`` field. + nonce: the nonce value originally sent in the initial authorization + request. This value should match the one inside the token. + + Returns: + An object representing the user. + """ + metadata = await self.load_metadata() + claims_params = { + "nonce": nonce, + "client_id": self._client_auth.client_id, + } + if "access_token" in token: + # If we got an `access_token`, there should be an `at_hash` claim + # in the `id_token` that we can check against. + claims_params["access_token"] = token["access_token"] + claims_cls = CodeIDToken + else: + claims_cls = ImplicitIDToken + + alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) + jwt = JsonWebToken(alg_values) + + claim_options = {"iss": {"values": [metadata["issuer"]]}} + + id_token = token["id_token"] + logger.debug("Attempting to decode JWT id_token %r", id_token) + + # Try to decode the keys in cache first, then retry by forcing the keys + # to be reloaded + jwk_set = await self.load_jwks() + try: + claims = jwt.decode( + id_token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claim_options, + claims_params=claims_params, + ) + except ValueError: + logger.info("Reloading JWKS after decode error") + jwk_set = await self.load_jwks(force=True) # try reloading the jwks + claims = jwt.decode( + id_token, + key=jwk_set, + claims_cls=claims_cls, + claims_options=claim_options, + claims_params=claims_params, + ) + + logger.debug("Decoded id_token JWT %r; validating", claims) + + claims.validate(leeway=120) # allows 2 min of clock skew + return UserInfo(claims) + + async def handle_redirect_request( + self, + request: SynapseRequest, + client_redirect_url: Optional[bytes], + ui_auth_session_id: Optional[str] = None, + ) -> str: + """Handle an incoming request to /login/sso/redirect + + It returns a redirect to the authorization endpoint with a few + parameters: + + - ``client_id``: the client ID set in ``oidc_config.client_id`` + - ``response_type``: ``code`` + - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback`` + - ``scope``: the list of scopes set in ``oidc_config.scopes`` + - ``state``: a random string + - ``nonce``: a random string + + In addition generating a redirect URL, we are setting a cookie with + a signed macaroon token containing the state, the nonce and the + client_redirect_url params. Those are then checked when the client + comes back from the provider. + + Args: + request: the incoming request from the browser. + We'll respond to it with a redirect and a cookie. + client_redirect_url: the URL that we should redirect the client to + when everything is done (or None for UI Auth) + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). + + Returns: + The redirect URL to the authorization endpoint. + + """ + + state = generate_token() + nonce = generate_token() + + if not client_redirect_url: + client_redirect_url = b"" + + cookie = self._token_generator.generate_oidc_session_token( + state=state, + session_data=OidcSessionData( + idp_id=self.idp_id, + nonce=nonce, + client_redirect_url=client_redirect_url.decode(), + ui_auth_session_id=ui_auth_session_id or "", + ), + ) + + # Set the cookies. See the comments on _SESSION_COOKIES for why there are two. + # + # we have to build the header by hand rather than calling request.addCookie + # because the latter does not support SameSite=None + # (https://twistedmatrix.com/trac/ticket/10088) + + for cookie_name, options in _SESSION_COOKIES: + request.cookies.append( + b"%s=%s; Max-Age=3600; %s" + % (cookie_name, cookie.encode("utf-8"), options) + ) + + metadata = await self.load_metadata() + authorization_endpoint = metadata.get("authorization_endpoint") + return prepare_grant_uri( + authorization_endpoint, + client_id=self._client_auth.client_id, + response_type="code", + redirect_uri=self._callback_url, + scope=self._scopes, + state=state, + nonce=nonce, + ) + + async def handle_oidc_callback( + self, request: SynapseRequest, session_data: "OidcSessionData", code: str + ) -> None: + """Handle an incoming request to /_synapse/client/oidc/callback + + By this time we have already validated the session on the synapse side, and + now need to do the provider-specific operations. This includes: + + - exchange the code with the provider using the ``token_endpoint`` (see + ``_exchange_code``) + - once we have the token, use it to either extract the UserInfo from + the ``id_token`` (``_parse_id_token``), or use the ``access_token`` + to fetch UserInfo from the ``userinfo_endpoint`` + (``_fetch_userinfo``) + - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and + finish the login + + Args: + request: the incoming request from the browser. + session_data: the session data, extracted from our cookie + code: The authorization code we got from the callback. + """ + # Exchange the code with the provider + try: + logger.debug("Exchanging OAuth2 code for a token") + token = await self._exchange_code(code) + except OidcError as e: + logger.exception("Could not exchange OAuth2 code") + self._sso_handler.render_error(request, e.error, e.error_description) + return + + logger.debug("Successfully obtained OAuth2 token data: %r", token) + + # Now that we have a token, get the userinfo, either by decoding the + # `id_token` or by fetching the `userinfo_endpoint`. + if self._uses_userinfo: + try: + userinfo = await self._fetch_userinfo(token) + except Exception as e: + logger.exception("Could not fetch userinfo") + self._sso_handler.render_error(request, "fetch_error", str(e)) + return + else: + try: + userinfo = await self._parse_id_token(token, nonce=session_data.nonce) + except Exception as e: + logger.exception("Invalid id_token") + self._sso_handler.render_error(request, "invalid_token", str(e)) + return + + # first check if we're doing a UIA + if session_data.ui_auth_session_id: + try: + remote_user_id = self._remote_id_from_userinfo(userinfo) + except Exception as e: + logger.exception("Could not extract remote user id") + self._sso_handler.render_error(request, "mapping_error", str(e)) + return + + return await self._sso_handler.complete_sso_ui_auth_request( + self.idp_id, remote_user_id, session_data.ui_auth_session_id, request + ) + + # otherwise, it's a login + logger.debug("Userinfo for OIDC login: %s", userinfo) + + # Ensure that the attributes of the logged in user meet the required + # attributes by checking the userinfo against attribute_requirements + # In order to deal with the fact that OIDC userinfo can contain many + # types of data, we wrap non-list values in lists. + if not self._sso_handler.check_required_attributes( + request, + {k: v if isinstance(v, list) else [v] for k, v in userinfo.items()}, + self._oidc_attribute_requirements, + ): + return + + # Call the mapper to register/login the user + try: + await self._complete_oidc_login( + userinfo, token, request, session_data.client_redirect_url + ) + except MappingException as e: + logger.exception("Could not map user") + self._sso_handler.render_error(request, "mapping_error", str(e)) + + async def _complete_oidc_login( + self, + userinfo: UserInfo, + token: Token, + request: SynapseRequest, + client_redirect_url: str, + ) -> None: + """Given a UserInfo response, complete the login flow + + UserInfo should have a claim that uniquely identifies users. This claim + is usually `sub`, but can be configured with `oidc_config.subject_claim`. + It is then used as an `external_id`. + + If we don't find the user that way, we should register the user, + mapping the localpart and the display name from the UserInfo. + + If a user already exists with the mxid we've mapped and allow_existing_users + is disabled, raise an exception. + + Otherwise, render a redirect back to the client_redirect_url with a loginToken. + + Args: + userinfo: an object representing the user + token: a dict with the tokens obtained from the provider + request: The request to respond to + client_redirect_url: The redirect URL passed in by the client. + + Raises: + MappingException: if there was an error while mapping some properties + """ + try: + remote_user_id = self._remote_id_from_userinfo(userinfo) + except Exception as e: + raise MappingException( + "Failed to extract subject from OIDC response: %s" % (e,) + ) + + # Older mapping providers don't accept the `failures` argument, so we + # try and detect support. + mapper_signature = inspect.signature( + self._user_mapping_provider.map_user_attributes + ) + supports_failures = "failures" in mapper_signature.parameters + + async def oidc_response_to_user_attributes(failures: int) -> UserAttributes: + """ + Call the mapping provider to map the OIDC userinfo and token to user attributes. + + This is backwards compatibility for abstraction for the SSO handler. + """ + if supports_failures: + attributes = await self._user_mapping_provider.map_user_attributes( + userinfo, token, failures + ) + else: + # If the mapping provider does not support processing failures, + # do not continually generate the same Matrix ID since it will + # continue to already be in use. Note that the error raised is + # arbitrary and will get turned into a MappingException. + if failures: + raise MappingException( + "Mapping provider does not support de-duplicating Matrix IDs" + ) + + attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore + userinfo, token + ) + + return UserAttributes(**attributes) + + async def grandfather_existing_users() -> Optional[str]: + if self._allow_existing_users: + # If allowing existing users we want to generate a single localpart + # and attempt to match it. + attributes = await oidc_response_to_user_attributes(failures=0) + + user_id = UserID(attributes.localpart, self._server_name).to_string() + users = await self._store.get_users_by_id_case_insensitive(user_id) + if users: + # If an existing matrix ID is returned, then use it. + if len(users) == 1: + previously_registered_user_id = next(iter(users)) + elif user_id in users: + previously_registered_user_id = user_id + else: + # Do not attempt to continue generating Matrix IDs. + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, users + ) + ) + + return previously_registered_user_id + + return None + + # Mapping providers might not have get_extra_attributes: only call this + # method if it exists. + extra_attributes = None + get_extra_attributes = getattr( + self._user_mapping_provider, "get_extra_attributes", None + ) + if get_extra_attributes: + extra_attributes = await get_extra_attributes(userinfo, token) + + await self._sso_handler.complete_sso_login_request( + self.idp_id, + remote_user_id, + request, + client_redirect_url, + oidc_response_to_user_attributes, + grandfather_existing_users, + extra_attributes, + ) + + def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: + """Extract the unique remote id from an OIDC UserInfo block + + Args: + userinfo: An object representing the user given by the OIDC provider + Returns: + remote user id + """ + remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo) + # Some OIDC providers use integer IDs, but Synapse expects external IDs + # to be strings. + return str(remote_user_id) + + +# number of seconds a newly-generated client secret should be valid for +CLIENT_SECRET_VALIDITY_SECONDS = 3600 + +# minimum remaining validity on a client secret before we should generate a new one +CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600 + + +class JwtClientSecret: + """A class which generates a new client secret on demand, based on a JWK + + This implementation is designed to comply with the requirements for Apple Sign in: + https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048 + + It looks like those requirements are based on https://tools.ietf.org/html/rfc7523, + but it's worth noting that we still put the generated secret in the "client_secret" + field (or rather, whereever client_auth_method puts it) rather than in a + client_assertion field in the body as that RFC seems to require. + """ + + def __init__( + self, + key: OidcProviderClientSecretJwtKey, + oauth_client_id: str, + oauth_issuer: str, + clock: Clock, + ): + self._key = key + self._oauth_client_id = oauth_client_id + self._oauth_issuer = oauth_issuer + self._clock = clock + self._cached_secret = b"" + self._cached_secret_replacement_time = 0 + + def __str__(self): + # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls + # encode_client_secret_basic, which calls "{}".format(secret), which ends up + # here. + return self._get_secret().decode("ascii") + + def __bytes__(self): + # if client_auth_method is client_secret_post, then ClientAuth.prepare calls + # encode_client_secret_post, which ends up here. + return self._get_secret() + + def _get_secret(self) -> bytes: + now = self._clock.time() + + # if we have enough validity on our existing secret, use it + if now < self._cached_secret_replacement_time: + return self._cached_secret + + issued_at = int(now) + expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS + + # we copy the configured header because jwt.encode modifies it. + header = dict(self._key.jwt_header) + + # see https://tools.ietf.org/html/rfc7523#section-3 + payload = { + "sub": self._oauth_client_id, + "aud": self._oauth_issuer, + "iat": issued_at, + "exp": expires_at, + **self._key.jwt_payload, + } + logger.info( + "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload + ) + self._cached_secret = jwt.encode(header, payload, self._key.key) + self._cached_secret_replacement_time = ( + expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS + ) + return self._cached_secret + + +class OidcSessionTokenGenerator: + """Methods for generating and checking OIDC Session cookies.""" + + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() + self._server_name = hs.hostname + self._macaroon_secret_key = hs.config.key.macaroon_secret_key + + def generate_oidc_session_token( + self, + state: str, + session_data: "OidcSessionData", + duration_in_ms: int = (60 * 60 * 1000), + ) -> str: + """Generates a signed token storing data about an OIDC session. + + When Synapse initiates an authorization flow, it creates a random state + and a random nonce. Those parameters are given to the provider and + should be verified when the client comes back from the provider. + It is also used to store the client_redirect_url, which is used to + complete the SSO login flow. + + Args: + state: The ``state`` parameter passed to the OIDC provider. + session_data: data to include in the session token. + duration_in_ms: An optional duration for the token in milliseconds. + Defaults to an hour. + + Returns: + A signed macaroon token with the session information. + """ + macaroon = pymacaroons.Macaroon( + location=self._server_name, + identifier="key", + key=self._macaroon_secret_key, + ) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("type = session") + macaroon.add_first_party_caveat("state = %s" % (state,)) + macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,)) + macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,)) + macaroon.add_first_party_caveat( + "client_redirect_url = %s" % (session_data.client_redirect_url,) + ) + macaroon.add_first_party_caveat( + "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,) + ) + now = self._clock.time_msec() + expiry = now + duration_in_ms + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + + return macaroon.serialize() + + def verify_oidc_session_token( + self, session: bytes, state: str + ) -> "OidcSessionData": + """Verifies and extract an OIDC session token. + + This verifies that a given session token was issued by this homeserver + and extract the nonce and client_redirect_url caveats. + + Args: + session: The session token to verify + state: The state the OIDC provider gave back + + Returns: + The data extracted from the session cookie + + Raises: + KeyError if an expected caveat is missing from the macaroon. + """ + macaroon = pymacaroons.Macaroon.deserialize(session) + + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = session") + v.satisfy_exact("state = %s" % (state,)) + v.satisfy_general(lambda c: c.startswith("nonce = ")) + v.satisfy_general(lambda c: c.startswith("idp_id = ")) + v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) + v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) + satisfy_expiry(v, self._clock.time_msec) + + v.verify(macaroon, self._macaroon_secret_key) + + # Extract the session data from the token. + nonce = get_value_from_macaroon(macaroon, "nonce") + idp_id = get_value_from_macaroon(macaroon, "idp_id") + client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url") + ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id") + return OidcSessionData( + nonce=nonce, + idp_id=idp_id, + client_redirect_url=client_redirect_url, + ui_auth_session_id=ui_auth_session_id, + ) + + +@attr.s(frozen=True, slots=True) +class OidcSessionData: + """The attributes which are stored in a OIDC session cookie""" + + # the Identity Provider being used + idp_id = attr.ib(type=str) + + # The `nonce` parameter passed to the OIDC provider. + nonce = attr.ib(type=str) + + # The URL the client gave when it initiated the flow. ("" if this is a UI Auth) + client_redirect_url = attr.ib(type=str) + + # The session ID of the ongoing UI Auth ("" if this is a login) + ui_auth_session_id = attr.ib(type=str) + + +UserAttributeDict = TypedDict( + "UserAttributeDict", + {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]}, +) +C = TypeVar("C") + + +class OidcMappingProvider(Generic[C]): + """A mapping provider maps a UserInfo object to user attributes. + + It should provide the API described by this class. + """ + + def __init__(self, config: C): + """ + Args: + config: A custom config object from this module, parsed by ``parse_config()`` + """ + + @staticmethod + def parse_config(config: dict) -> C: + """Parse the dict provided by the homeserver's config + + Args: + config: A dictionary containing configuration options for this provider + + Returns: + A custom config object for this module + """ + raise NotImplementedError() + + def get_remote_user_id(self, userinfo: UserInfo) -> str: + """Get a unique user ID for this user. + + Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object. + + Args: + userinfo: An object representing the user given by the OIDC provider + + Returns: + A unique user ID + """ + raise NotImplementedError() + + async def map_user_attributes( + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: + """Map a `UserInfo` object into user attributes. + + Args: + userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider + failures: How many times a call to this function with this + UserInfo has resulted in a failure. + + Returns: + A dict containing the ``localpart`` and (optionally) the ``display_name`` + """ + raise NotImplementedError() + + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + """Map a `UserInfo` object into additional attributes passed to the client during login. + + Args: + userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider + + Returns: + A dict containing additional attributes. Must be JSON serializable. + """ + return {} + + +# Used to clear out "None" values in templates +def jinja_finalize(thing): + return thing if thing is not None else "" + + +env = Environment(finalize=jinja_finalize) + + +@attr.s(slots=True, frozen=True) +class JinjaOidcMappingConfig: + subject_claim = attr.ib(type=str) + localpart_template = attr.ib(type=Optional[Template]) + display_name_template = attr.ib(type=Optional[Template]) + email_template = attr.ib(type=Optional[Template]) + extra_attributes = attr.ib(type=Dict[str, Template]) + + +class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): + """An implementation of a mapping provider based on Jinja templates. + + This is the default mapping provider. + """ + + def __init__(self, config: JinjaOidcMappingConfig): + self._config = config + + @staticmethod + def parse_config(config: dict) -> JinjaOidcMappingConfig: + subject_claim = config.get("subject_claim", "sub") + + def parse_template_config(option_name: str) -> Optional[Template]: + if option_name not in config: + return None + try: + return env.from_string(config[option_name]) + except Exception as e: + raise ConfigError("invalid jinja template", path=[option_name]) from e + + localpart_template = parse_template_config("localpart_template") + display_name_template = parse_template_config("display_name_template") + email_template = parse_template_config("email_template") + + extra_attributes = {} # type Dict[str, Template] + if "extra_attributes" in config: + extra_attributes_config = config.get("extra_attributes") or {} + if not isinstance(extra_attributes_config, dict): + raise ConfigError("must be a dict", path=["extra_attributes"]) + + for key, value in extra_attributes_config.items(): + try: + extra_attributes[key] = env.from_string(value) + except Exception as e: + raise ConfigError( + "invalid jinja template", path=["extra_attributes", key] + ) from e + + return JinjaOidcMappingConfig( + subject_claim=subject_claim, + localpart_template=localpart_template, + display_name_template=display_name_template, + email_template=email_template, + extra_attributes=extra_attributes, + ) + + def get_remote_user_id(self, userinfo: UserInfo) -> str: + return userinfo[self._config.subject_claim] + + async def map_user_attributes( + self, userinfo: UserInfo, token: Token, failures: int + ) -> UserAttributeDict: + localpart = None + + if self._config.localpart_template: + localpart = self._config.localpart_template.render(user=userinfo).strip() + + # Ensure only valid characters are included in the MXID. + localpart = map_username_to_mxid_localpart(localpart) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid. + localpart += str(failures) if failures else "" + + def render_template_field(template: Optional[Template]) -> Optional[str]: + if template is None: + return None + return template.render(user=userinfo).strip() + + display_name = render_template_field(self._config.display_name_template) + if display_name == "": + display_name = None + + emails = [] # type: List[str] + email = render_template_field(self._config.email_template) + if email: + emails.append(email) + + return UserAttributeDict( + localpart=localpart, display_name=display_name, emails=emails + ) + + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + extras = {} # type: Dict[str, str] + for key, template in self._config.extra_attributes.items(): + try: + extras[key] = template.render(user=userinfo).strip() + except Exception as e: + # Log an error and skip this value (don't break login for this). + logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e)) + return extras |