summary refs log tree commit diff
path: root/synapse/handlers/oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/oidc.py')
-rw-r--r--synapse/handlers/oidc.py381
1 files changed, 345 insertions, 36 deletions
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 9759daf043..867973dcca 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -12,14 +12,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import binascii
 import inspect
+import json
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
 from urllib.parse import urlencode, urlparse
 
 import attr
+import unpaddedbase64
 from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken, jwt
+from authlib.jose import JsonWebToken, JWTClaims
+from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
 from authlib.oidc.core import CodeIDToken, UserInfo
@@ -35,9 +49,12 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 from twisted.web.http_headers import Headers
 
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
+from synapse.http.server import finish_request
+from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -88,6 +105,8 @@ class Token(TypedDict):
 #: there is no real point of doing this in our case.
 JWK = Dict[str, str]
 
+C = TypeVar("C")
+
 
 #: A JWK Set, as per RFC7517 sec 5.
 class JWKS(TypedDict):
@@ -247,6 +266,80 @@ class OidcHandler:
 
         await oidc_provider.handle_oidc_callback(request, session_data, code)
 
+    async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        This extracts the logout_token from the request and tries to figure out
+        which OpenID Provider it is comming from. This works by matching the iss claim
+        with the issuer and the aud claim with the client_id.
+
+        Since at this point we don't know who signed the JWT, we can't just
+        decode it using authlib since it will always verifies the signature. We
+        have to decode it manually without validating the signature. The actual JWT
+        verification is done in the `OidcProvider.handler_backchannel_logout` method,
+        once we figured out which provider sent the request.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+        logout_token = parse_string(request, "logout_token")
+        if logout_token is None:
+            raise SynapseError(400, "Missing logout_token in request")
+
+        # A JWT looks like this:
+        #    header.payload.signature
+        # where all parts are encoded with urlsafe base64.
+        # The aud and iss claims we care about are in the payload part, which
+        # is a JSON object.
+        try:
+            # By destructuring the list after splitting, we ensure that we have
+            # exactly 3 segments
+            _, payload, _ = logout_token.split(".")
+        except ValueError:
+            raise SynapseError(400, "Invalid logout_token in request")
+
+        try:
+            payload_bytes = unpaddedbase64.decode_base64(payload)
+            claims = json_decoder.decode(payload_bytes.decode("utf-8"))
+        except (json.JSONDecodeError, binascii.Error, UnicodeError):
+            raise SynapseError(400, "Invalid logout_token payload in request")
+
+        try:
+            # Let's extract the iss and aud claims
+            iss = claims["iss"]
+            aud = claims["aud"]
+            # The aud claim can be either a string or a list of string. Here we
+            # normalize it as a list of strings.
+            if isinstance(aud, str):
+                aud = [aud]
+
+            # Check that we have the right types for the aud and the iss claims
+            if not isinstance(iss, str) or not isinstance(aud, list):
+                raise TypeError()
+            for a in aud:
+                if not isinstance(a, str):
+                    raise TypeError()
+
+            # At this point we properly checked both claims types
+            issuer: str = iss
+            audience: List[str] = aud
+        except (TypeError, KeyError):
+            raise SynapseError(400, "Invalid issuer/audience in logout_token")
+
+        # Now that we know the audience and the issuer, we can figure out from
+        # what provider it is coming from
+        oidc_provider: Optional[OidcProvider] = None
+        for provider in self._providers.values():
+            if provider.issuer == issuer and provider.client_id in audience:
+                oidc_provider = provider
+                break
+
+        if oidc_provider is None:
+            raise SynapseError(400, "Could not find the OP that issued this event")
+
+        # Ask the provider to handle the logout request.
+        await oidc_provider.handle_backchannel_logout(request, logout_token)
+
 
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint"""
@@ -342,6 +435,7 @@ class OidcProvider:
         self.idp_brand = provider.idp_brand
 
         self._sso_handler = hs.get_sso_handler()
+        self._device_handler = hs.get_device_handler()
 
         self._sso_handler.register_identity_provider(self)
 
@@ -400,6 +494,41 @@ class OidcProvider:
             # If we're not using userinfo, we need a valid jwks to validate the ID token
             m.validate_jwks_uri()
 
+        if self._config.backchannel_logout_enabled:
+            if not m.get("backchannel_logout_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled for issuer %r"
+                    "but it does not advertise support for it",
+                    self.issuer,
+                )
+
+            elif not m.get("backchannel_logout_session_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled and supported "
+                    "by issuer %r but it might not send a session ID with "
+                    "logout tokens, which is required for the logouts to work",
+                    self.issuer,
+                )
+
+            if not self._config.backchannel_logout_ignore_sub:
+                # If OIDC backchannel logouts are enabled, the provider mapping provider
+                # should use the `sub` claim. We verify that by mapping a dumb user and
+                # see if we get back the sub claim
+                user = UserInfo({"sub": "thisisasubject"})
+                try:
+                    subject = self._user_mapping_provider.get_remote_user_id(user)
+                    if subject != user["sub"]:
+                        raise ValueError("Unexpected subject")
+                except Exception:
+                    logger.warning(
+                        f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
+                        "but it looks like the configured `user_mapping_provider` "
+                        "does not use the `sub` claim as subject. If it is the case, "
+                        "and you want Synapse to ignore the `sub` claim in OIDC "
+                        "Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
+                        "to `true` in the issuer config."
+                    )
+
     @property
     def _uses_userinfo(self) -> bool:
         """Returns True if the ``userinfo_endpoint`` should be used.
@@ -415,6 +544,16 @@ class OidcProvider:
             or self._user_profile_method == "userinfo_endpoint"
         )
 
+    @property
+    def issuer(self) -> str:
+        """The issuer identifying this provider."""
+        return self._config.issuer
+
+    @property
+    def client_id(self) -> str:
+        """The client_id used when interacting with this provider."""
+        return self._config.client_id
+
     async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
         """Return the provider metadata.
 
@@ -662,6 +801,59 @@ class OidcProvider:
 
         return UserInfo(resp)
 
+    async def _verify_jwt(
+        self,
+        alg_values: List[str],
+        token: str,
+        claims_cls: Type[C],
+        claims_options: Optional[dict] = None,
+        claims_params: Optional[dict] = None,
+    ) -> C:
+        """Decode and validate a JWT, re-fetching the JWKS as needed.
+
+        Args:
+            alg_values: list of `alg` values allowed when verifying the JWT.
+            token: the JWT.
+            claims_cls: the JWTClaims class to use to validate the claims.
+            claims_options: dict of options passed to the `claims_cls` constructor.
+            claims_params: dict of params passed to the `claims_cls` constructor.
+
+        Returns:
+            The decoded claims in the JWT.
+        """
+        jwt = JsonWebToken(alg_values)
+
+        logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
+
+        # Try to decode the keys in cache first, then retry by forcing the keys
+        # to be reloaded
+        jwk_set = await self.load_jwks()
+        try:
+            claims = jwt.decode(
+                token,
+                key=jwk_set,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
+                claims_params=claims_params,
+            )
+        except ValueError:
+            logger.info("Reloading JWKS after decode error")
+            jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
+            claims = jwt.decode(
+                token,
+                key=jwk_set,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
+                claims_params=claims_params,
+            )
+
+        logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
+
+        claims.validate(
+            now=self._clock.time(), leeway=120
+        )  # allows 2 min of clock skew
+        return claims
+
     async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
         """Return an instance of UserInfo from token's ``id_token``.
 
@@ -675,13 +867,13 @@ class OidcProvider:
             The decoded claims in the ID token.
         """
         id_token = token.get("id_token")
-        logger.debug("Attempting to decode JWT id_token %r", id_token)
 
         # That has been theoritically been checked by the caller, so even though
         # assertion are not enabled in production, it is mainly here to appease mypy
         assert id_token is not None
 
         metadata = await self.load_metadata()
+
         claims_params = {
             "nonce": nonce,
             "client_id": self._client_auth.client_id,
@@ -691,38 +883,17 @@ class OidcProvider:
             # in the `id_token` that we can check against.
             claims_params["access_token"] = token["access_token"]
 
-        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
-        jwt = JsonWebToken(alg_values)
-
-        claim_options = {"iss": {"values": [metadata["issuer"]]}}
+        claims_options = {"iss": {"values": [metadata["issuer"]]}}
 
-        # Try to decode the keys in cache first, then retry by forcing the keys
-        # to be reloaded
-        jwk_set = await self.load_jwks()
-        try:
-            claims = jwt.decode(
-                id_token,
-                key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
-                claims_params=claims_params,
-            )
-        except ValueError:
-            logger.info("Reloading JWKS after decode error")
-            jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
-            claims = jwt.decode(
-                id_token,
-                key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
-                claims_params=claims_params,
-            )
-
-        logger.debug("Decoded id_token JWT %r; validating", claims)
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
 
-        claims.validate(
-            now=self._clock.time(), leeway=120
-        )  # allows 2 min of clock skew
+        claims = await self._verify_jwt(
+            alg_values=alg_values,
+            token=id_token,
+            claims_cls=CodeIDToken,
+            claims_options=claims_options,
+            claims_params=claims_params,
+        )
 
         return claims
 
@@ -1043,6 +1214,146 @@ class OidcProvider:
         # to be strings.
         return str(remote_user_id)
 
+    async def handle_backchannel_logout(
+        self, request: SynapseRequest, logout_token: str
+    ) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        The OIDC Provider posts a logout token to this endpoint when a user
+        session ends. That token is a JWT signed with the same keys as
+        ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
+        validate the JWT and figure out what session to end.
+
+        Args:
+            request: The request to respond to
+            logout_token: The logout token (a JWT) extracted from the request body
+        """
+        # Back-Channel Logout can be disabled in the config, hence this check.
+        # This is not that important for now since Synapse is registered
+        # manually to the OP, so not specifying the backchannel-logout URI is
+        # as effective than disabling it here. It might make more sense if we
+        # support dynamic registration in Synapse at some point.
+        if not self._config.backchannel_logout_enabled:
+            logger.warning(
+                f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
+            )
+
+            # TODO: this responds with a 400 status code, which is what the OIDC
+            # Back-Channel Logout spec expects, but spec also suggests answering with
+            # a JSON object, with the `error` and `error_description` fields set, which
+            # we are not doing here.
+            # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
+            raise SynapseError(
+                400, "OpenID Connect Back-Channel Logout is disabled for this provider"
+            )
+
+        metadata = await self.load_metadata()
+
+        # As per OIDC Back-Channel Logout 1.0 sec. 2.4:
+        #   A Logout Token MUST be signed and MAY also be encrypted. The same
+        #   keys are used to sign and encrypt Logout Tokens as are used for ID
+        #   Tokens. If the Logout Token is encrypted, it SHOULD replicate the
+        #   iss (issuer) claim in the JWT Header Parameters, as specified in
+        #   Section 5.3 of [JWT].
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+        # As per sec. 2.6:
+        #    3. Validate the iss, aud, and iat Claims in the same way they are
+        #       validated in ID Tokens.
+        # Which means the audience should contain Synapse's client_id and the
+        # issuer should be the IdP issuer
+        claims_options = {
+            "iss": {"values": [metadata["issuer"]]},
+            "aud": {"values": [self.client_id]},
+        }
+
+        try:
+            claims = await self._verify_jwt(
+                alg_values=alg_values,
+                token=logout_token,
+                claims_cls=LogoutToken,
+                claims_options=claims_options,
+            )
+        except JoseError:
+            logger.exception("Invalid logout_token")
+            raise SynapseError(400, "Invalid logout_token")
+
+        # As per sec. 2.6:
+        #    4. Verify that the Logout Token contains a sub Claim, a sid Claim,
+        #       or both.
+        #    5. Verify that the Logout Token contains an events Claim whose
+        #       value is JSON object containing the member name
+        #       http://schemas.openid.net/event/backchannel-logout.
+        #    6. Verify that the Logout Token does not contain a nonce Claim.
+        # This is all verified by the LogoutToken claims class, so at this
+        # point the `sid` claim exists and is a string.
+        sid: str = claims.get("sid")
+
+        # If the `sub` claim was included in the logout token, we check that it matches
+        # that it matches the right user. We can have cases where the `sub` claim is not
+        # the ID saved in database, so we let admins disable this check in config.
+        sub: Optional[str] = claims.get("sub")
+        expected_user_id: Optional[str] = None
+        if sub is not None and not self._config.backchannel_logout_ignore_sub:
+            expected_user_id = await self._store.get_user_by_external_id(
+                self.idp_id, sub
+            )
+
+        # Invalidate any running user-mapping sessions, in-flight login tokens and
+        # active devices
+        await self._sso_handler.revoke_sessions_for_provider_session_id(
+            auth_provider_id=self.idp_id,
+            auth_provider_session_id=sid,
+            expected_user_id=expected_user_id,
+        )
+
+        request.setResponseCode(200)
+        request.setHeader(b"Cache-Control", b"no-cache, no-store")
+        request.setHeader(b"Pragma", b"no-cache")
+        finish_request(request)
+
+
+class LogoutToken(JWTClaims):
+    """
+    Holds and verify claims of a logout token, as per
+    https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
+    """
+
+    REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
+
+    def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
+        """Validate everything in claims payload."""
+        super().validate(now, leeway)
+        self.validate_sid()
+        self.validate_events()
+        self.validate_nonce()
+
+    def validate_sid(self) -> None:
+        """Ensure the sid claim is present"""
+        sid = self.get("sid")
+        if not sid:
+            raise MissingClaimError("sid")
+
+        if not isinstance(sid, str):
+            raise InvalidClaimError("sid")
+
+    def validate_nonce(self) -> None:
+        """Ensure the nonce claim is absent"""
+        if "nonce" in self:
+            raise InvalidClaimError("nonce")
+
+    def validate_events(self) -> None:
+        """Ensure the events claim is present and with the right value"""
+        events = self.get("events")
+        if not events:
+            raise MissingClaimError("events")
+
+        if not isinstance(events, dict):
+            raise InvalidClaimError("events")
+
+        if "http://schemas.openid.net/event/backchannel-logout" not in events:
+            raise InvalidClaimError("events")
+
 
 # number of seconds a newly-generated client secret should be valid for
 CLIENT_SECRET_VALIDITY_SECONDS = 3600
@@ -1112,6 +1423,7 @@ class JwtClientSecret:
         logger.info(
             "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
         )
+        jwt = JsonWebToken(header["alg"])
         self._cached_secret = jwt.encode(header, payload, self._key.key)
         self._cached_secret_replacement_time = (
             expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
@@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict):
     emails: List[str]
 
 
-C = TypeVar("C")
-
-
 class OidcMappingProvider(Generic[C]):
     """A mapping provider maps a UserInfo object to user attributes.