diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 5943f08e91..749d7e93b0 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -191,6 +191,7 @@ class SsoHandler:
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
self._error_template = hs.config.sso.sso_error_template
self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
self._profile_handler = hs.get_profile_handler()
@@ -1026,6 +1027,76 @@ class SsoHandler:
return True
+ async def revoke_sessions_for_provider_session_id(
+ self,
+ auth_provider_id: str,
+ auth_provider_session_id: str,
+ expected_user_id: Optional[str] = None,
+ ) -> None:
+ """Revoke any devices and in-flight logins tied to a provider session.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ auth_provider_session_id: The session ID from the provider to logout
+ expected_user_id: The user we're expecting to logout. If set, it will ignore
+ sessions belonging to other users and log an error.
+ """
+ # Invalidate any running user-mapping sessions
+ to_delete = []
+ for session_id, session in self._username_mapping_sessions.items():
+ if (
+ session.auth_provider_id == auth_provider_id
+ and session.auth_provider_session_id == auth_provider_session_id
+ ):
+ to_delete.append(session_id)
+
+ for session_id in to_delete:
+ logger.info("Revoking mapping session %s", session_id)
+ del self._username_mapping_sessions[session_id]
+
+ # Invalidate any in-flight login tokens
+ await self._store.invalidate_login_tokens_by_session_id(
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
+
+ # Fetch any device(s) in the store associated with the session ID.
+ devices = await self._store.get_devices_by_auth_provider_session_id(
+ auth_provider_id=auth_provider_id,
+ auth_provider_session_id=auth_provider_session_id,
+ )
+
+ # We have no guarantee that all the devices of that session are for the same
+ # `user_id`. Hence, we have to iterate over the list of devices and log them out
+ # one by one.
+ for device in devices:
+ user_id = device["user_id"]
+ device_id = device["device_id"]
+
+ # If the user_id associated with that device/session is not the one we got
+ # out of the `sub` claim, skip that device and show log an error.
+ if expected_user_id is not None and user_id != expected_user_id:
+ logger.error(
+ "Received a logout notification from SSO provider "
+ f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
+ f"a session ID ({auth_provider_session_id!r}) which belongs to "
+ f"{user_id!r}. This may happen when the SSO provider user mapper "
+ "uses something else than the standard attribute as mapping ID. "
+ "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
+ "in the provider config if that is the case."
+ )
+ continue
+
+ logger.info(
+ "Logging out %r (device %r) via SSO (%r) logout notification (session %r).",
+ user_id,
+ device_id,
+ auth_provider_id,
+ auth_provider_session_id,
+ )
+ await self._device_handler.delete_devices(user_id, [device_id])
+
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
|