summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11414.feature1
-rw-r--r--docs/openid.md14
-rw-r--r--docs/usage/configuration/config_documentation.md9
-rw-r--r--synapse/config/oidc.py12
-rw-r--r--synapse/handlers/oidc.py381
-rw-r--r--synapse/handlers/sso.py71
-rw-r--r--synapse/rest/synapse/client/oidc/__init__.py4
-rw-r--r--synapse/rest/synapse/client/oidc/backchannel_logout_resource.py35
-rw-r--r--synapse/storage/databases/main/registration.py21
-rw-r--r--tests/rest/client/test_auth.py390
-rw-r--r--tests/rest/client/utils.py55
-rw-r--r--tests/server.py6
-rw-r--r--tests/test_utils/oidc.py27
13 files changed, 960 insertions, 66 deletions
diff --git a/changelog.d/11414.feature b/changelog.d/11414.feature
new file mode 100644
index 0000000000..fc035e50a7
--- /dev/null
+++ b/changelog.d/11414.feature
@@ -0,0 +1 @@
+Support back-channel logouts from OpenID Connect providers.
diff --git a/docs/openid.md b/docs/openid.md
index 87ebea4c29..37c5eb244d 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -49,6 +49,13 @@ setting in your configuration file.
 See the [configuration manual](usage/configuration/config_documentation.md#oidc_providers) for some sample settings, as well as
 the text below for example configurations for specific providers.
 
+## OIDC Back-Channel Logout
+
+Synapse supports receiving [OpenID Connect Back-Channel Logout](https://openid.net/specs/openid-connect-backchannel-1_0.html) notifications.
+
+This lets the OpenID Connect Provider notify Synapse when a user logs out, so that Synapse can end that user session.
+This feature can be enabled by setting the `backchannel_logout_enabled` property to `true` in the provider configuration, and setting the following URL as destination for Back-Channel Logout notifications in your OpenID Connect Provider: `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout`
+
 ## Sample configs
 
 Here are a few configs for providers that should work with Synapse.
@@ -123,6 +130,9 @@ oidc_providers:
 
 [Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat.
 
+Keycloak supports OIDC Back-Channel Logout, which sends logout notification to Synapse, so that Synapse users get logged out when they log out from Keycloak.
+This can be optionally enabled by setting `backchannel_logout_enabled` to `true` in the Synapse configuration, and by setting the "Backchannel Logout URL" in Keycloak.
+
 Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to install Keycloak and set up a realm.
 
 1. Click `Clients` in the sidebar and click `Create`
@@ -144,6 +154,8 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to
 | Client Protocol | `openid-connect` |
 | Access Type | `confidential` |
 | Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` |
+| Backchannel Logout URL (optional) | `[synapse public baseurl]/_synapse/client/oidc/backchannel_logout` |
+| Backchannel Logout Session Required (optional) | `On` |
 
 5. Click `Save`
 6. On the Credentials tab, update the fields:
@@ -167,7 +179,9 @@ oidc_providers:
       config:
         localpart_template: "{{ user.preferred_username }}"
         display_name_template: "{{ user.name }}"
+    backchannel_logout_enabled: true # Optional
 ```
+
 ### Auth0
 
 [Auth0][auth0] is a hosted SaaS IdP solution.
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 97fb505a5f..44358faf59 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -3021,6 +3021,15 @@ Options for each entry include:
      which is set to the claims returned by the UserInfo Endpoint and/or
      in the ID Token.
 
+* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications. 
+  Those notifications are expected to be received on `/_synapse/client/oidc/backchannel_logout`.
+  Defaults to `false`.
+
+* `backchannel_logout_ignore_sub`: by default, the OIDC Back-Channel Logout feature checks that the
+  `sub` claim matches the subject claim received during login. This check can be disabled by setting
+  this to `true`. Defaults to `false`.
+
+  You might want to disable this if the `subject_claim` returned by the mapping provider is not `sub`.
 
 It is possible to configure Synapse to only allow logins if certain attributes
 match particular values in the OIDC userinfo. The requirements can be listed under
diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py
index 5418a332da..0bd83f4010 100644
--- a/synapse/config/oidc.py
+++ b/synapse/config/oidc.py
@@ -123,6 +123,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
         "userinfo_endpoint": {"type": "string"},
         "jwks_uri": {"type": "string"},
         "skip_verification": {"type": "boolean"},
+        "backchannel_logout_enabled": {"type": "boolean"},
+        "backchannel_logout_ignore_sub": {"type": "boolean"},
         "user_profile_method": {
             "type": "string",
             "enum": ["auto", "userinfo_endpoint"],
@@ -292,6 +294,10 @@ def _parse_oidc_config_dict(
         token_endpoint=oidc_config.get("token_endpoint"),
         userinfo_endpoint=oidc_config.get("userinfo_endpoint"),
         jwks_uri=oidc_config.get("jwks_uri"),
+        backchannel_logout_enabled=oidc_config.get("backchannel_logout_enabled", False),
+        backchannel_logout_ignore_sub=oidc_config.get(
+            "backchannel_logout_ignore_sub", False
+        ),
         skip_verification=oidc_config.get("skip_verification", False),
         user_profile_method=oidc_config.get("user_profile_method", "auto"),
         allow_existing_users=oidc_config.get("allow_existing_users", False),
@@ -368,6 +374,12 @@ class OidcProviderConfig:
     # "openid" scope is used.
     jwks_uri: Optional[str]
 
+    # Whether Synapse should react to backchannel logouts
+    backchannel_logout_enabled: bool
+
+    # Whether Synapse should ignore the `sub` claim in backchannel logouts or not.
+    backchannel_logout_ignore_sub: bool
+
     # Whether to skip metadata verification
     skip_verification: bool
 
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 9759daf043..867973dcca 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -12,14 +12,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import binascii
 import inspect
+import json
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
 from urllib.parse import urlencode, urlparse
 
 import attr
+import unpaddedbase64
 from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken, jwt
+from authlib.jose import JsonWebToken, JWTClaims
+from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
 from authlib.oidc.core import CodeIDToken, UserInfo
@@ -35,9 +49,12 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 from twisted.web.http_headers import Headers
 
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
+from synapse.http.server import finish_request
+from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -88,6 +105,8 @@ class Token(TypedDict):
 #: there is no real point of doing this in our case.
 JWK = Dict[str, str]
 
+C = TypeVar("C")
+
 
 #: A JWK Set, as per RFC7517 sec 5.
 class JWKS(TypedDict):
@@ -247,6 +266,80 @@ class OidcHandler:
 
         await oidc_provider.handle_oidc_callback(request, session_data, code)
 
+    async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        This extracts the logout_token from the request and tries to figure out
+        which OpenID Provider it is comming from. This works by matching the iss claim
+        with the issuer and the aud claim with the client_id.
+
+        Since at this point we don't know who signed the JWT, we can't just
+        decode it using authlib since it will always verifies the signature. We
+        have to decode it manually without validating the signature. The actual JWT
+        verification is done in the `OidcProvider.handler_backchannel_logout` method,
+        once we figured out which provider sent the request.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+        logout_token = parse_string(request, "logout_token")
+        if logout_token is None:
+            raise SynapseError(400, "Missing logout_token in request")
+
+        # A JWT looks like this:
+        #    header.payload.signature
+        # where all parts are encoded with urlsafe base64.
+        # The aud and iss claims we care about are in the payload part, which
+        # is a JSON object.
+        try:
+            # By destructuring the list after splitting, we ensure that we have
+            # exactly 3 segments
+            _, payload, _ = logout_token.split(".")
+        except ValueError:
+            raise SynapseError(400, "Invalid logout_token in request")
+
+        try:
+            payload_bytes = unpaddedbase64.decode_base64(payload)
+            claims = json_decoder.decode(payload_bytes.decode("utf-8"))
+        except (json.JSONDecodeError, binascii.Error, UnicodeError):
+            raise SynapseError(400, "Invalid logout_token payload in request")
+
+        try:
+            # Let's extract the iss and aud claims
+            iss = claims["iss"]
+            aud = claims["aud"]
+            # The aud claim can be either a string or a list of string. Here we
+            # normalize it as a list of strings.
+            if isinstance(aud, str):
+                aud = [aud]
+
+            # Check that we have the right types for the aud and the iss claims
+            if not isinstance(iss, str) or not isinstance(aud, list):
+                raise TypeError()
+            for a in aud:
+                if not isinstance(a, str):
+                    raise TypeError()
+
+            # At this point we properly checked both claims types
+            issuer: str = iss
+            audience: List[str] = aud
+        except (TypeError, KeyError):
+            raise SynapseError(400, "Invalid issuer/audience in logout_token")
+
+        # Now that we know the audience and the issuer, we can figure out from
+        # what provider it is coming from
+        oidc_provider: Optional[OidcProvider] = None
+        for provider in self._providers.values():
+            if provider.issuer == issuer and provider.client_id in audience:
+                oidc_provider = provider
+                break
+
+        if oidc_provider is None:
+            raise SynapseError(400, "Could not find the OP that issued this event")
+
+        # Ask the provider to handle the logout request.
+        await oidc_provider.handle_backchannel_logout(request, logout_token)
+
 
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint"""
@@ -342,6 +435,7 @@ class OidcProvider:
         self.idp_brand = provider.idp_brand
 
         self._sso_handler = hs.get_sso_handler()
+        self._device_handler = hs.get_device_handler()
 
         self._sso_handler.register_identity_provider(self)
 
@@ -400,6 +494,41 @@ class OidcProvider:
             # If we're not using userinfo, we need a valid jwks to validate the ID token
             m.validate_jwks_uri()
 
+        if self._config.backchannel_logout_enabled:
+            if not m.get("backchannel_logout_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled for issuer %r"
+                    "but it does not advertise support for it",
+                    self.issuer,
+                )
+
+            elif not m.get("backchannel_logout_session_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled and supported "
+                    "by issuer %r but it might not send a session ID with "
+                    "logout tokens, which is required for the logouts to work",
+                    self.issuer,
+                )
+
+            if not self._config.backchannel_logout_ignore_sub:
+                # If OIDC backchannel logouts are enabled, the provider mapping provider
+                # should use the `sub` claim. We verify that by mapping a dumb user and
+                # see if we get back the sub claim
+                user = UserInfo({"sub": "thisisasubject"})
+                try:
+                    subject = self._user_mapping_provider.get_remote_user_id(user)
+                    if subject != user["sub"]:
+                        raise ValueError("Unexpected subject")
+                except Exception:
+                    logger.warning(
+                        f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
+                        "but it looks like the configured `user_mapping_provider` "
+                        "does not use the `sub` claim as subject. If it is the case, "
+                        "and you want Synapse to ignore the `sub` claim in OIDC "
+                        "Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
+                        "to `true` in the issuer config."
+                    )
+
     @property
     def _uses_userinfo(self) -> bool:
         """Returns True if the ``userinfo_endpoint`` should be used.
@@ -415,6 +544,16 @@ class OidcProvider:
             or self._user_profile_method == "userinfo_endpoint"
         )
 
+    @property
+    def issuer(self) -> str:
+        """The issuer identifying this provider."""
+        return self._config.issuer
+
+    @property
+    def client_id(self) -> str:
+        """The client_id used when interacting with this provider."""
+        return self._config.client_id
+
     async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
         """Return the provider metadata.
 
@@ -662,6 +801,59 @@ class OidcProvider:
 
         return UserInfo(resp)
 
+    async def _verify_jwt(
+        self,
+        alg_values: List[str],
+        token: str,
+        claims_cls: Type[C],
+        claims_options: Optional[dict] = None,
+        claims_params: Optional[dict] = None,
+    ) -> C:
+        """Decode and validate a JWT, re-fetching the JWKS as needed.
+
+        Args:
+            alg_values: list of `alg` values allowed when verifying the JWT.
+            token: the JWT.
+            claims_cls: the JWTClaims class to use to validate the claims.
+            claims_options: dict of options passed to the `claims_cls` constructor.
+            claims_params: dict of params passed to the `claims_cls` constructor.
+
+        Returns:
+            The decoded claims in the JWT.
+        """
+        jwt = JsonWebToken(alg_values)
+
+        logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
+
+        # Try to decode the keys in cache first, then retry by forcing the keys
+        # to be reloaded
+        jwk_set = await self.load_jwks()
+        try:
+            claims = jwt.decode(
+                token,
+                key=jwk_set,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
+                claims_params=claims_params,
+            )
+        except ValueError:
+            logger.info("Reloading JWKS after decode error")
+            jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
+            claims = jwt.decode(
+                token,
+                key=jwk_set,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
+                claims_params=claims_params,
+            )
+
+        logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
+
+        claims.validate(
+            now=self._clock.time(), leeway=120
+        )  # allows 2 min of clock skew
+        return claims
+
     async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
         """Return an instance of UserInfo from token's ``id_token``.
 
@@ -675,13 +867,13 @@ class OidcProvider:
             The decoded claims in the ID token.
         """
         id_token = token.get("id_token")
-        logger.debug("Attempting to decode JWT id_token %r", id_token)
 
         # That has been theoritically been checked by the caller, so even though
         # assertion are not enabled in production, it is mainly here to appease mypy
         assert id_token is not None
 
         metadata = await self.load_metadata()
+
         claims_params = {
             "nonce": nonce,
             "client_id": self._client_auth.client_id,
@@ -691,38 +883,17 @@ class OidcProvider:
             # in the `id_token` that we can check against.
             claims_params["access_token"] = token["access_token"]
 
-        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
-        jwt = JsonWebToken(alg_values)
-
-        claim_options = {"iss": {"values": [metadata["issuer"]]}}
+        claims_options = {"iss": {"values": [metadata["issuer"]]}}
 
-        # Try to decode the keys in cache first, then retry by forcing the keys
-        # to be reloaded
-        jwk_set = await self.load_jwks()
-        try:
-            claims = jwt.decode(
-                id_token,
-                key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
-                claims_params=claims_params,
-            )
-        except ValueError:
-            logger.info("Reloading JWKS after decode error")
-            jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
-            claims = jwt.decode(
-                id_token,
-                key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
-                claims_params=claims_params,
-            )
-
-        logger.debug("Decoded id_token JWT %r; validating", claims)
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
 
-        claims.validate(
-            now=self._clock.time(), leeway=120
-        )  # allows 2 min of clock skew
+        claims = await self._verify_jwt(
+            alg_values=alg_values,
+            token=id_token,
+            claims_cls=CodeIDToken,
+            claims_options=claims_options,
+            claims_params=claims_params,
+        )
 
         return claims
 
@@ -1043,6 +1214,146 @@ class OidcProvider:
         # to be strings.
         return str(remote_user_id)
 
+    async def handle_backchannel_logout(
+        self, request: SynapseRequest, logout_token: str
+    ) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        The OIDC Provider posts a logout token to this endpoint when a user
+        session ends. That token is a JWT signed with the same keys as
+        ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
+        validate the JWT and figure out what session to end.
+
+        Args:
+            request: The request to respond to
+            logout_token: The logout token (a JWT) extracted from the request body
+        """
+        # Back-Channel Logout can be disabled in the config, hence this check.
+        # This is not that important for now since Synapse is registered
+        # manually to the OP, so not specifying the backchannel-logout URI is
+        # as effective than disabling it here. It might make more sense if we
+        # support dynamic registration in Synapse at some point.
+        if not self._config.backchannel_logout_enabled:
+            logger.warning(
+                f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
+            )
+
+            # TODO: this responds with a 400 status code, which is what the OIDC
+            # Back-Channel Logout spec expects, but spec also suggests answering with
+            # a JSON object, with the `error` and `error_description` fields set, which
+            # we are not doing here.
+            # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
+            raise SynapseError(
+                400, "OpenID Connect Back-Channel Logout is disabled for this provider"
+            )
+
+        metadata = await self.load_metadata()
+
+        # As per OIDC Back-Channel Logout 1.0 sec. 2.4:
+        #   A Logout Token MUST be signed and MAY also be encrypted. The same
+        #   keys are used to sign and encrypt Logout Tokens as are used for ID
+        #   Tokens. If the Logout Token is encrypted, it SHOULD replicate the
+        #   iss (issuer) claim in the JWT Header Parameters, as specified in
+        #   Section 5.3 of [JWT].
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+        # As per sec. 2.6:
+        #    3. Validate the iss, aud, and iat Claims in the same way they are
+        #       validated in ID Tokens.
+        # Which means the audience should contain Synapse's client_id and the
+        # issuer should be the IdP issuer
+        claims_options = {
+            "iss": {"values": [metadata["issuer"]]},
+            "aud": {"values": [self.client_id]},
+        }
+
+        try:
+            claims = await self._verify_jwt(
+                alg_values=alg_values,
+                token=logout_token,
+                claims_cls=LogoutToken,
+                claims_options=claims_options,
+            )
+        except JoseError:
+            logger.exception("Invalid logout_token")
+            raise SynapseError(400, "Invalid logout_token")
+
+        # As per sec. 2.6:
+        #    4. Verify that the Logout Token contains a sub Claim, a sid Claim,
+        #       or both.
+        #    5. Verify that the Logout Token contains an events Claim whose
+        #       value is JSON object containing the member name
+        #       http://schemas.openid.net/event/backchannel-logout.
+        #    6. Verify that the Logout Token does not contain a nonce Claim.
+        # This is all verified by the LogoutToken claims class, so at this
+        # point the `sid` claim exists and is a string.
+        sid: str = claims.get("sid")
+
+        # If the `sub` claim was included in the logout token, we check that it matches
+        # that it matches the right user. We can have cases where the `sub` claim is not
+        # the ID saved in database, so we let admins disable this check in config.
+        sub: Optional[str] = claims.get("sub")
+        expected_user_id: Optional[str] = None
+        if sub is not None and not self._config.backchannel_logout_ignore_sub:
+            expected_user_id = await self._store.get_user_by_external_id(
+                self.idp_id, sub
+            )
+
+        # Invalidate any running user-mapping sessions, in-flight login tokens and
+        # active devices
+        await self._sso_handler.revoke_sessions_for_provider_session_id(
+            auth_provider_id=self.idp_id,
+            auth_provider_session_id=sid,
+            expected_user_id=expected_user_id,
+        )
+
+        request.setResponseCode(200)
+        request.setHeader(b"Cache-Control", b"no-cache, no-store")
+        request.setHeader(b"Pragma", b"no-cache")
+        finish_request(request)
+
+
+class LogoutToken(JWTClaims):
+    """
+    Holds and verify claims of a logout token, as per
+    https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
+    """
+
+    REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
+
+    def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
+        """Validate everything in claims payload."""
+        super().validate(now, leeway)
+        self.validate_sid()
+        self.validate_events()
+        self.validate_nonce()
+
+    def validate_sid(self) -> None:
+        """Ensure the sid claim is present"""
+        sid = self.get("sid")
+        if not sid:
+            raise MissingClaimError("sid")
+
+        if not isinstance(sid, str):
+            raise InvalidClaimError("sid")
+
+    def validate_nonce(self) -> None:
+        """Ensure the nonce claim is absent"""
+        if "nonce" in self:
+            raise InvalidClaimError("nonce")
+
+    def validate_events(self) -> None:
+        """Ensure the events claim is present and with the right value"""
+        events = self.get("events")
+        if not events:
+            raise MissingClaimError("events")
+
+        if not isinstance(events, dict):
+            raise InvalidClaimError("events")
+
+        if "http://schemas.openid.net/event/backchannel-logout" not in events:
+            raise InvalidClaimError("events")
+
 
 # number of seconds a newly-generated client secret should be valid for
 CLIENT_SECRET_VALIDITY_SECONDS = 3600
@@ -1112,6 +1423,7 @@ class JwtClientSecret:
         logger.info(
             "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
         )
+        jwt = JsonWebToken(header["alg"])
         self._cached_secret = jwt.encode(header, payload, self._key.key)
         self._cached_secret_replacement_time = (
             expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
@@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict):
     emails: List[str]
 
 
-C = TypeVar("C")
-
-
 class OidcMappingProvider(Generic[C]):
     """A mapping provider maps a UserInfo object to user attributes.
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 5943f08e91..749d7e93b0 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -191,6 +191,7 @@ class SsoHandler:
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
         self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
         self._error_template = hs.config.sso.sso_error_template
         self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
         self._profile_handler = hs.get_profile_handler()
@@ -1026,6 +1027,76 @@ class SsoHandler:
 
         return True
 
+    async def revoke_sessions_for_provider_session_id(
+        self,
+        auth_provider_id: str,
+        auth_provider_session_id: str,
+        expected_user_id: Optional[str] = None,
+    ) -> None:
+        """Revoke any devices and in-flight logins tied to a provider session.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            auth_provider_session_id: The session ID from the provider to logout
+            expected_user_id: The user we're expecting to logout. If set, it will ignore
+                sessions belonging to other users and log an error.
+        """
+        # Invalidate any running user-mapping sessions
+        to_delete = []
+        for session_id, session in self._username_mapping_sessions.items():
+            if (
+                session.auth_provider_id == auth_provider_id
+                and session.auth_provider_session_id == auth_provider_session_id
+            ):
+                to_delete.append(session_id)
+
+        for session_id in to_delete:
+            logger.info("Revoking mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]
+
+        # Invalidate any in-flight login tokens
+        await self._store.invalidate_login_tokens_by_session_id(
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+        # Fetch any device(s) in the store associated with the session ID.
+        devices = await self._store.get_devices_by_auth_provider_session_id(
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+        # We have no guarantee that all the devices of that session are for the same
+        # `user_id`. Hence, we have to iterate over the list of devices and log them out
+        # one by one.
+        for device in devices:
+            user_id = device["user_id"]
+            device_id = device["device_id"]
+
+            # If the user_id associated with that device/session is not the one we got
+            # out of the `sub` claim, skip that device and show log an error.
+            if expected_user_id is not None and user_id != expected_user_id:
+                logger.error(
+                    "Received a logout notification from SSO provider "
+                    f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
+                    f"a session ID ({auth_provider_session_id!r}) which belongs to "
+                    f"{user_id!r}. This may happen when the SSO provider user mapper "
+                    "uses something else than the standard attribute as mapping ID. "
+                    "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
+                    "in the provider config if that is the case."
+                )
+                continue
+
+            logger.info(
+                "Logging out %r (device %r) via SSO (%r) logout notification (session %r).",
+                user_id,
+                device_id,
+                auth_provider_id,
+                auth_provider_session_id,
+            )
+            await self._device_handler.delete_devices(user_id, [device_id])
+
 
 def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     """Extract the session ID from the cookie
diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index 81fec39659..e4b28ce3df 100644
--- a/synapse/rest/synapse/client/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -17,6 +17,9 @@ from typing import TYPE_CHECKING
 
 from twisted.web.resource import Resource
 
+from synapse.rest.synapse.client.oidc.backchannel_logout_resource import (
+    OIDCBackchannelLogoutResource,
+)
 from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
 
 if TYPE_CHECKING:
@@ -29,6 +32,7 @@ class OIDCResource(Resource):
     def __init__(self, hs: "HomeServer"):
         Resource.__init__(self)
         self.putChild(b"callback", OIDCCallbackResource(hs))
+        self.putChild(b"backchannel_logout", OIDCBackchannelLogoutResource(hs))
 
 
 __all__ = ["OIDCResource"]
diff --git a/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py
new file mode 100644
index 0000000000..e07e76855a
--- /dev/null
+++ b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py
@@ -0,0 +1,35 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.server import DirectServeJsonResource
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCBackchannelLogoutResource(DirectServeJsonResource):
+    isLeaf = 1
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._oidc_handler = hs.get_oidc_handler()
+
+    async def _async_render_POST(self, request: SynapseRequest) -> None:
+        await self._oidc_handler.handle_backchannel_logout(request)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 0255295317..5167089e03 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1920,6 +1920,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             self._clock.time_msec(),
         )
 
+    async def invalidate_login_tokens_by_session_id(
+        self, auth_provider_id: str, auth_provider_session_id: str
+    ) -> None:
+        """Invalidate login tokens with the given IdP session ID.
+
+        Args:
+            auth_provider_id: The SSO Identity Provider that the user authenticated with
+                to get this token
+            auth_provider_session_id: The session ID advertised by the SSO Identity
+                Provider
+        """
+        await self.db_pool.simple_update(
+            table="login_tokens",
+            keyvalues={
+                "auth_provider_id": auth_provider_id,
+                "auth_provider_session_id": auth_provider_session_id,
+            },
+            updatevalues={"used_ts": self._clock.time_msec()},
+            desc="invalidate_login_tokens_by_session_id",
+        )
+
     @cached()
     async def is_guest(self, user_id: str) -> bool:
         res = await self.db_pool.simple_select_one_onecol(
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index ebf653d018..847294dc8e 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -21,7 +22,7 @@ from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.api.constants import ApprovalNoticeMedium, LoginType
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, SynapseError
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client import account, auth, devices, login, logout, register
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -32,8 +33,8 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.utils import TEST_OIDC_CONFIG
-from tests.server import FakeChannel
+from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
+from tests.server import FakeChannel, make_request
 from tests.unittest import override_config, skip_unless
 
 
@@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": refresh_token},
         )
 
-    def is_access_token_valid(self, access_token: str) -> bool:
-        """
-        Checks whether an access token is valid, returning whether it is or not.
-        """
-        code = self.make_request(
-            "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
-        ).code
-
-        # Either 200 or 401 is what we get back; anything else is a bug.
-        assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
-
-        return code == HTTPStatus.OK
-
     def test_login_issue_refresh_token(self) -> None:
         """
         A login response should include a refresh_token only if asked.
@@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         self.reactor.advance(59.0)
 
         # Both tokens should still be valid.
-        self.assertTrue(self.is_access_token_valid(refreshable_access_token))
-        self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK)
+        self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
 
         # Advance to 61 s (just past 1 minute, the time of expiry)
         self.reactor.advance(2.0)
 
         # Only the non-refreshable token is still valid.
-        self.assertFalse(self.is_access_token_valid(refreshable_access_token))
-        self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(
+            refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
+        self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
 
         # Advance to 599 s (just shy of 10 minutes, the time of expiry)
         self.reactor.advance(599.0 - 61.0)
 
         # It's still the case that only the non-refreshable token is still valid.
-        self.assertFalse(self.is_access_token_valid(refreshable_access_token))
-        self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(
+            refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
+        self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
 
         # Advance to 601 s (just past 10 minutes, the time of expiry)
         self.reactor.advance(2.0)
 
         # Now neither token is valid.
-        self.assertFalse(self.is_access_token_valid(refreshable_access_token))
-        self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(
+            refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
+        self.helper.whoami(
+            nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
 
     @override_config(
         {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
@@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # and no refresh token
         self.assertEqual(_table_length("access_tokens"), 0)
         self.assertEqual(_table_length("refresh_tokens"), 0)
+
+
+def oidc_config(
+    id: str, with_localpart_template: bool, **kwargs: Any
+) -> Dict[str, Any]:
+    """Sample OIDC provider config used in backchannel logout tests.
+
+    Args:
+        id: IDP ID for this provider
+        with_localpart_template: Set to `true` to have a default localpart_template in
+            the `user_mapping_provider` config and skip the user mapping session
+        **kwargs: rest of the config
+
+    Returns:
+        A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of
+        the HS config
+    """
+    config: Dict[str, Any] = {
+        "idp_id": id,
+        "idp_name": id,
+        "issuer": TEST_OIDC_ISSUER,
+        "client_id": "test-client-id",
+        "client_secret": "test-client-secret",
+        "scopes": ["openid"],
+    }
+
+    if with_localpart_template:
+        config["user_mapping_provider"] = {
+            "config": {"localpart_template": "{{ user.sub }}"}
+        }
+    else:
+        config["user_mapping_provider"] = {"config": {}}
+
+    config.update(kwargs)
+
+    return config
+
+
+@skip_unless(HAS_OIDC, "Requires OIDC")
+class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
+    servlets = [
+        account.register_servlets,
+        login.register_servlets,
+    ]
+
+    def default_config(self) -> Dict[str, Any]:
+        config = super().default_config()
+
+        # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
+        # False, so synapse will see the requested uri as http://..., so using http in
+        # the public_baseurl stops Synapse trying to redirect to https.
+        config["public_baseurl"] = "http://synapse.test"
+
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resource_dict = super().create_resource_dict()
+        resource_dict.update(build_synapse_client_resource_tree(self.hs))
+        return resource_dict
+
+    def submit_logout_token(self, logout_token: str) -> FakeChannel:
+        return self.make_request(
+            "POST",
+            "/_synapse/client/oidc/backchannel_logout",
+            content=f"logout_token={logout_token}",
+            content_is_form=True,
+        )
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_simple_logout(self) -> None:
+        """
+        Receiving a logout token should logout the user
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        login_resp, first_grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=True
+        )
+        first_access_token: str = login_resp["access_token"]
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+        login_resp, second_grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=True
+        )
+        second_access_token: str = login_resp["access_token"]
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        self.assertNotEqual(first_grant.sid, second_grant.sid)
+        self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+        # Logging out of the first session
+        logout_token = fake_oidc_server.generate_logout_token(first_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out of the second session
+        logout_token = fake_oidc_server.generate_logout_token(second_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_logout_during_login(self) -> None:
+        """
+        It should revoke login tokens when receiving a logout token
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        # Get an authentication, and logout before submitting the logout token
+        client_redirect_url = "https://x"
+        userinfo = {"sub": user}
+        channel, grant = self.helper.auth_via_oidc(
+            fake_oidc_server,
+            userinfo,
+            client_redirect_url,
+            with_sid=True,
+        )
+
+        # expect a confirmation page
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.text_body,
+        )
+        assert m, channel.text_body
+        login_token = m.group(1)
+
+        # Submit a logout
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        # Now try to exchange the login token
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        # It should have failed
+        self.assertEqual(channel.code, 403)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=False,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_logout_during_mapping(self) -> None:
+        """
+        It should stop ongoing user mapping session when receiving a logout token
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        # Get an authentication, and logout before submitting the logout token
+        client_redirect_url = "https://x"
+        userinfo = {"sub": user}
+        channel, grant = self.helper.auth_via_oidc(
+            fake_oidc_server,
+            userinfo,
+            client_redirect_url,
+            with_sid=True,
+        )
+
+        # Expect a user mapping page
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+
+        # We should have a user_mapping_session cookie
+        cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+        assert cookie_headers
+        cookies: Dict[str, str] = {}
+        for h in cookie_headers:
+            key, value = h.split(";")[0].split("=", maxsplit=1)
+            cookies[key] = value
+
+        user_mapping_session_id = cookies["username_mapping_session"]
+
+        # Getting that session should not raise
+        session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+        self.assertIsNotNone(session)
+
+        # Submit a logout
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        # Now it should raise
+        with self.assertRaises(SynapseError):
+            self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=False,
+                )
+            ]
+        }
+    )
+    def test_disabled(self) -> None:
+        """
+        Receiving a logout token should do nothing if it is disabled in the config
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        login_resp, grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=True
+        )
+        access_token: str = login_resp["access_token"]
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out shouldn't work
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 400)
+
+        # And the token should still be valid
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_no_sid(self) -> None:
+        """
+        Receiving a logout token without `sid` during the login should do nothing
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        login_resp, grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=False
+        )
+        access_token: str = login_resp["access_token"]
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out shouldn't work
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 400)
+
+        # And the token should still be valid
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    "first",
+                    issuer="https://first-issuer.com/",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                ),
+                oidc_config(
+                    "second",
+                    issuer="https://second-issuer.com/",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                ),
+            ]
+        }
+    )
+    def test_multiple_providers(self) -> None:
+        """
+        It should be able to distinguish login tokens from two different IdPs
+        """
+        first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/")
+        second_server = self.helper.fake_oidc_server(
+            issuer="https://second-issuer.com/"
+        )
+        user = "john"
+
+        login_resp, first_grant = self.helper.login_via_oidc(
+            first_server, user, with_sid=True, idp_id="oidc-first"
+        )
+        first_access_token: str = login_resp["access_token"]
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+        login_resp, second_grant = self.helper.login_via_oidc(
+            second_server, user, with_sid=True, idp_id="oidc-second"
+        )
+        second_access_token: str = login_resp["access_token"]
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        # `sid` in the fake providers are generated by a counter, so the first grant of
+        # each provider should give the same SID
+        self.assertEqual(first_grant.sid, second_grant.sid)
+        self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+        # Logging out of the first session
+        logout_token = first_server.generate_logout_token(first_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out of the second session
+        logout_token = second_server.generate_logout_token(second_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 967d229223..706399fae5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -553,6 +553,34 @@ class RestHelper:
 
         return channel.json_body
 
+    def whoami(
+        self,
+        access_token: str,
+        expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK,
+    ) -> JsonDict:
+        """Perform a 'whoami' request, which can be a quick way to check for access
+        token validity
+
+        Args:
+            access_token: The user token to use during the request
+            expect_code: The return code to expect from attempting the whoami request
+        """
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            "account/whoami",
+            access_token=access_token,
+        )
+
+        assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % (
+            expect_code,
+            channel.code,
+            channel.result["body"],
+        )
+
+        return channel.json_body
+
     def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
         """Create a ``FakeOidcServer``.
 
@@ -572,6 +600,7 @@ class RestHelper:
         fake_server: FakeOidcServer,
         remote_user_id: str,
         with_sid: bool = False,
+        idp_id: Optional[str] = None,
         expected_status: int = 200,
     ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
         """Log in (as a new user) via OIDC
@@ -588,7 +617,11 @@ class RestHelper:
         client_redirect_url = "https://x"
         userinfo = {"sub": remote_user_id}
         channel, grant = self.auth_via_oidc(
-            fake_server, userinfo, client_redirect_url, with_sid=with_sid
+            fake_server,
+            userinfo,
+            client_redirect_url,
+            with_sid=with_sid,
+            idp_id=idp_id,
         )
 
         # expect a confirmation page
@@ -623,6 +656,7 @@ class RestHelper:
         client_redirect_url: Optional[str] = None,
         ui_auth_session_id: Optional[str] = None,
         with_sid: bool = False,
+        idp_id: Optional[str] = None,
     ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Perform an OIDC authentication flow via a mock OIDC provider.
 
@@ -648,6 +682,7 @@ class RestHelper:
             ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
                 of the UI auth.
             with_sid: if True, generates a random `sid` (OIDC session ID)
+            idp_id: if set, explicitely chooses one specific IDP
 
         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -665,7 +700,9 @@ class RestHelper:
                 oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
             else:
                 # otherwise, hit the login redirect endpoint
-                oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+                oauth_uri = self.initiate_sso_login(
+                    client_redirect_url, cookies, idp_id=idp_id
+                )
 
         # we now have a URI for the OIDC IdP, but we skip that and go straight
         # back to synapse's OIDC callback resource. However, we do need the "state"
@@ -742,7 +779,10 @@ class RestHelper:
         return channel, grant
 
     def initiate_sso_login(
-        self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+        self,
+        client_redirect_url: Optional[str],
+        cookies: MutableMapping[str, str],
+        idp_id: Optional[str] = None,
     ) -> str:
         """Make a request to the login-via-sso redirect endpoint, and return the target
 
@@ -753,6 +793,7 @@ class RestHelper:
             client_redirect_url: the client redirect URL to pass to the login redirect
                 endpoint
             cookies: any cookies returned will be added to this dict
+            idp_id: if set, explicitely chooses one specific IDP
 
         Returns:
             the URI that the client gets redirected to (ie, the SSO server)
@@ -761,6 +802,12 @@ class RestHelper:
         if client_redirect_url:
             params["redirectUrl"] = client_redirect_url
 
+        uri = "/_matrix/client/r0/login/sso/redirect"
+        if idp_id is not None:
+            uri = f"{uri}/{idp_id}"
+
+        uri = f"{uri}?{urllib.parse.urlencode(params)}"
+
         # hit the redirect url (which should redirect back to the redirect url. This
         # is the easiest way of figuring out what the Host header ought to be set to
         # to keep Synapse happy.
@@ -768,7 +815,7 @@ class RestHelper:
             self.hs.get_reactor(),
             self.site,
             "GET",
-            "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+            uri,
         )
         assert channel.code == 302
 
diff --git a/tests/server.py b/tests/server.py
index 8b1d186219..b1730fcc8d 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -362,6 +362,12 @@ def make_request(
     # Twisted expects to be at the end of the content when parsing the request.
     req.content.seek(0, SEEK_END)
 
+    # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded
+    # bodies if the Content-Length header is missing
+    req.requestHeaders.addRawHeader(
+        b"Content-Length", str(len(content)).encode("ascii")
+    )
+
     if access_token:
         req.requestHeaders.addRawHeader(
             b"Authorization", b"Bearer " + access_token.encode("ascii")
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index de134bbc89..1461d23ee8 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -51,6 +51,8 @@ class FakeOidcServer:
     get_userinfo_handler: Mock
     post_token_handler: Mock
 
+    sid_counter: int = 0
+
     def __init__(self, clock: Clock, issuer: str):
         from authlib.jose import ECKey, KeySet
 
@@ -146,7 +148,7 @@ class FakeOidcServer:
         return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
 
     def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
-        now = self._clock.time()
+        now = int(self._clock.time())
         id_token = {
             **grant.userinfo,
             "iss": self.issuer,
@@ -166,6 +168,26 @@ class FakeOidcServer:
 
         return self._sign(id_token)
 
+    def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str:
+        now = int(self._clock.time())
+        logout_token = {
+            "iss": self.issuer,
+            "aud": grant.client_id,
+            "iat": now,
+            "jti": random_string(10),
+            "events": {
+                "http://schemas.openid.net/event/backchannel-logout": {},
+            },
+        }
+
+        if grant.sid is not None:
+            logout_token["sid"] = grant.sid
+
+        if "sub" in grant.userinfo:
+            logout_token["sub"] = grant.userinfo["sub"]
+
+        return self._sign(logout_token)
+
     def id_token_override(self, overrides: dict):
         """Temporarily patch the ID token generated by the token endpoint."""
         return patch.object(self, "_id_token_overrides", overrides)
@@ -183,7 +205,8 @@ class FakeOidcServer:
         code = random_string(10)
         sid = None
         if with_sid:
-            sid = random_string(10)
+            sid = str(self.sid_counter)
+            self.sid_counter += 1
 
         grant = FakeAuthorizationGrant(
             userinfo=userinfo,